From 6ced3c12ff94a408bfb1fcd99b11750b67c7431e Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Mon, 29 Apr 2024 14:00:55 -0500 Subject: [PATCH 01/96] Mark unneeded instances as "getting deprecated" (#1265) * Add a flag * Add flag check and messages --------- Co-authored-by: root --- include/ck/ck.hpp | 3 +++ ...l_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp | 5 +++++ ...ht_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp | 5 +++++ ...3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp | 5 +++++ 4 files changed, 18 insertions(+) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 0bda8b759..31dcb5f1b 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -236,6 +236,9 @@ #define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) #endif // CK_WORKAROUND_DENORM_FIX +// set flag to 1 to build deprecated instances +#define CK_BUILD_DEPRECATED 1 + namespace ck { enum struct InMemoryDataOperationEnum diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp index d46be53ba..e2480db10 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp @@ -26,6 +26,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_ BF8, F8>>>& instances) { +#if CK_BUILD_DEPRECATED +#pragma message "These instances are getting deprecated" // 1. Default add_device_operation_instances( instances, @@ -44,6 +46,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_ Empty_Tuple, NDHWGC, ConvBwdDataFilter1x1Stride1Pad0>{}); +#else +#pragma message "These instances were deprecated" +#endif } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp index 7f9493f60..36c20d5f9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp @@ -23,6 +23,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_ BF8, F8>>>& instances) { +#if CK_BUILD_DEPRECATED +#pragma message "These instances are getting deprecated" // 1. 
Default add_device_operation_instances( instances, @@ -41,6 +43,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_ GKZYXC, NDHWGK, ConvBwdWeightFilter1x1Stride1Pad0>{}); +#else +#pragma message "These instances were deprecated" +#endif } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp index 4651c67a7..cc9c592d7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp @@ -24,6 +24,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance PassThrough, F8>>>& instances) { +#if CK_BUILD_DEPRECATED +#pragma message "These instances are getting deprecated" add_device_operation_instances( instances, device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3, @@ -48,6 +50,9 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance Empty_Tuple, NDHWGK, ConvFwd1x1S1P0>{}); +#else +#pragma message "These instances were deprecated" +#endif } } // namespace instance -- GitLab From 0f7e8ec48574b30588c70b1321acdf3dd0133b1b Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Tue, 30 Apr 2024 17:28:19 +0200 Subject: [PATCH 02/96] Fix example CMakeLists.txt (#1267) Add proper dependency target. --- example/59_grouped_gemm_multi_ABD/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt index 78f683289..e49056a94 100644 --- a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt +++ b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt @@ -1,7 +1,7 @@ add_custom_target(example_grouped_gemm_xdl_multi_abd) add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp) -add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16) +add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16) add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp) -add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8) +add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8) -- GitLab From f6b3f4715d19f9f22421524533d5eeb3851694bb Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 30 Apr 2024 14:44:30 -0700 Subject: [PATCH 03/96] [CI][Tests] Add a daily cron job to build CK instances for gfx9;gfx10;gfx11. 
(#1271) * add a daily build for instances for gfx9;gfx10;gfx11 * fix jenkins logic for instances only build * fix the path for instance_only build * reduce the number of build threads to 32 --- Jenkinsfile | 63 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index ee841a180..2f449b6d8 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -534,16 +534,16 @@ def Build_CK(Map conf=[:]){ if (params.RUN_PERFORMANCE_TESTS && navi_node == 0 && mi300_node == 0 ){ //we only need the ckProfiler to run the performance tests, so we pack and stash it //do not stash profiler on Navi or MI300 nodes - sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' - stash name: "ckProfiler.tar.gz" + sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' + stash name: "ckProfiler.tar.gz" } if (params.RUN_FULL_QA && mi300_node == 0 ){ - // build deb packages for all MI100/200/300 targets and prepare to export - sh 'make -j package' - archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' - archiveArtifacts artifacts: 'composablekernel-tests_*.deb' - sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' - stash name: "ckprofiler_0.2.0_amd64.deb" + // build deb packages for all MI100/200/300 targets and prepare to export + sh 'make -j package' + archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' + archiveArtifacts artifacts: 'composablekernel-tests_*.deb' + sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' + stash name: "ckprofiler_0.2.0_amd64.deb" } } if (params.hipTensor_test && navi_node == 0 ){ @@ -660,7 +660,8 @@ def process_results(Map conf=[:]){ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.1;COMPILER_VERSION= 0 21 * * * % ROCMVERSION=6.1;COMPILER_VERSION=;COMPILER_COMMIT= 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false - 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : "" + 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false + 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : "" pipeline { agent none @@ -727,6 +728,10 @@ pipeline { name: "RUN_CODEGEN_TESTS", defaultValue: true, description: "Run the codegen tests (default: ON)") + booleanParam( + name: "BUILD_INSTANCES_ONLY", + defaultValue: false, + description: "Test building instances for various architectures simultaneously (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -824,7 +829,7 @@ pipeline { -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx908;gfx90a" \ - -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j check""" + -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& make -j check""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -848,12 +853,12 @@ pipeline { setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ -DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " \ - -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """ + -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -868,12 +873,12 @@ pipeline { } agent{ label rocmnode("gfx942") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -884,27 +889,47 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } } agent{ label rocmnode("gfx908 || gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx908;gfx90a" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } + stage("Build CK instances for different targets") + { + when { + beforeAgent true + expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx90a;gfx1030;gfx1101" \ + -D INSTANCES_ONLY=ON \ + -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& make -j32 """ + } + steps{ + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } stage("Build CK and run Tests on Navi21") { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } } agent{ label rocmnode("navi21") } environment{ @@ -924,7 +949,7 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } } agent{ label rocmnode("navi32") } environment{ -- GitLab From 43579900a9df6d632f045d0c189ea518dac03aa0 Mon Sep 17 00:00:00 2001 From: Sam Wu <22262939+samjwu@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:44:59 -0600 Subject: [PATCH 04/96] Update documentation requirements and configurations (#1272) * Update documentation requirements Set rocm-docs-core to v1.1.1 * Update RTD config Set Python 3.10 for rocm-docs-core >= v1.0.0 --- .readthedocs.yaml | 2 +- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 16 ++-------------- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 9e6678abe..b3299fa4e 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -15,4 +15,4 @@ python: build: os: ubuntu-22.04 tools: - python: "3.8" + python: "3.10" diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index a85454243..dc1824931 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.38.1 +rocm-docs-core==1.1.1 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 801726ed6..9a451d970 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile requirements.in @@ -48,12 +48,6 @@ idna==3.4 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 - # via - # sphinx - # sphinxcontrib-bibtex -importlib-resources==6.1.0 - # via rocm-docs-core jinja2==3.1.2 # via # myst-parser @@ -99,8 +93,6 @@ pyjwt[crypto]==2.6.0 # via pygithub pynacl==1.5.0 # via pygithub -pytz==2023.3.post1 - # via babel pyyaml==6.0 # via # myst-parser @@ -111,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.38.1 +rocm-docs-core==1.1.1 # via -r requirements.in six==1.16.0 # via @@ -165,7 +157,3 @@ urllib3==1.26.18 # via requests wrapt==1.15.0 # via deprecated -zipp==3.17.0 - # via - # importlib-metadata - # importlib-resources -- GitLab From a2d0bdd5a9fe6d38367fe48cce88bbab28a3baf0 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 30 Apr 2024 22:45:22 -0500 Subject: [PATCH 05/96] Add an ignore (#1270) --- ...a_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp | 1 + ...weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp | 1 + ...conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp | 1 + 3 files changed, 3 insertions(+) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp index e2480db10..3f191ab6b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp @@ -48,6 +48,7 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_ ConvBwdDataFilter1x1Stride1Pad0>{}); #else #pragma message "These instances were deprecated" + std::ignore = instances; #endif } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp index 36c20d5f9..6e7f22b7e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp @@ -45,6 +45,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_ ConvBwdWeightFilter1x1Stride1Pad0>{}); #else #pragma message "These instances were deprecated" + std::ignore = instances; #endif } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp index cc9c592d7..7b5ddf0a8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp @@ -52,6 +52,7 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance ConvFwd1x1S1P0>{}); #else #pragma message "These instances were deprecated" + std::ignore = instances; #endif } -- GitLab From f0bf1e31255bf4d674dd3ca641ec7f1d46141e22 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 1 May 2024 10:07:14 -0700 Subject: [PATCH 06/96] [CI] Focus CI stages on MI200 nodes for resource optimization (#1273) --- Jenkinsfile | 32 ++++++++------------------------ 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 2f449b6d8..d334549bb 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -814,21 +814,21 @@ pipeline { { parallel { - stage("Run Codegen Tests on MI100/MI200") + stage("Run Codegen Tests on MI200") { when { beforeAgent true expression { params.RUN_CODEGEN_TESTS.toBoolean() } } options { retry(2) } - agent{ label rocmnode("gfx908 || gfx90a")} + agent{ label rocmnode("gfx90a")} environment{ setup_args = "NO_CK_BUILD" execute_args = """ cd ../codegen && rm -rf build && mkdir build && cd build && \ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ -D 
CMAKE_BUILD_TYPE=Release \ - -D GPU_TARGETS="gfx908;gfx90a" \ + -D GPU_TARGETS="gfx90a" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j check""" } steps{ @@ -842,13 +842,13 @@ pipeline { { parallel { - stage("Build CK and run Tests on MI100/MI200/MI300") + stage("Build CK for all gfx9 targets") { when { beforeAgent true expression { params.RUN_FULL_QA.toBoolean() } } - agent{ label rocmnode("gfx908 || gfx90a") } + agent{ label rocmnode("gfx90a") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ @@ -885,13 +885,13 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on MI100/MI200") + stage("Build CK and run Tests on MI200") { when { beforeAgent true expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } } - agent{ label rocmnode("gfx908 || gfx90a") } + agent{ label rocmnode("gfx90a") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ @@ -972,27 +972,11 @@ pipeline { { parallel { - stage("Run ckProfiler: gfx90*") - { - when { - beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() } - } - options { retry(2) } - agent{ label rocmnode("gfx908 || gfx90a")} - environment{ - setup_args = """ -DGPU_TARGETS="gfx908;gfx90a" -DBUILD_DEV=On """ - } - steps{ - runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') - cleanWs() - } - } stage("Run ckProfiler: gfx90a") { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() } + expression { params.RUN_PERFORMANCE_TESTS.toBoolean() } } options { retry(2) } agent{ label rocmnode("gfx90a")} -- GitLab From 7797f7c7a101bc2d50e815ce39c4f75046843131 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 1 May 2024 15:34:56 -0700 Subject: [PATCH 07/96] Downgrade minimum required python version to 3.6 (#1274) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a3a9801cc..e3113a31d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,7 +26,7 @@ set(version 1.1.0) project(composable_kernel VERSION ${version} LANGUAGES CXX) include(CTest) -find_package(Python3 3.8 COMPONENTS Interpreter REQUIRED) +find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") -- GitLab From 08d51d9bc4ec275fce3ad0a01a08ab1fd45636bc Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 2 May 2024 11:27:59 -0700 Subject: [PATCH 08/96] add missing vector header (#1275) --- include/ck/host_utility/flush_cache.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index 805fb571f..f490edb67 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -5,6 +5,7 @@ #include #include +#include #include "ck/ck.hpp" #include "ck/stream_config.hpp" -- GitLab From 6d073d31bbc7d39d8b170d549f2af61970378150 Mon Sep 17 00:00:00 2001 From: Sam Wu <22262939+samjwu@users.noreply.github.com> Date: Mon, 6 May 2024 10:07:39 -0600 Subject: [PATCH 09/96] Add ROCm Doc team as codeowners for RTD yaml (#1277) Also add component owners as codeowners 
for header directory --- .github/CODEOWNERS | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 01e3bee0b..bc49ac166 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,7 +1,8 @@ * @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex # Documentation files -docs/* @ROCm/rocm-documentation -*.md @ROCm/rocm-documentation -*.rst @ROCm/rocm-documentation +docs/* @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex +*.md @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex +*.rst @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex +.readthedocs.yaml @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex # Header directory for Doxygen documentation -library/include/* @ROCm/rocm-documentation +library/include/* @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex -- GitLab From 851c3ed1576fed2593b824fb83b3f9349f6aeb09 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 7 May 2024 22:32:54 +0800 Subject: [PATCH 10/96] [CK_TILE] support alibi (#1269) * add alibi support * fix code * update code based on comment * Support more hdim * fix fp8 bias * support seqlen_k=0 case * remove unused printf * fix format --------- Co-authored-by: rocking --- example/ck_tile/01_fmha/CMakeLists.txt | 2 +- example/ck_tile/01_fmha/README.md | 23 +- example/ck_tile/01_fmha/bias.hpp | 100 ++++++++ example/ck_tile/01_fmha/fmha_fwd.cpp | 116 ++++++++-- example/ck_tile/01_fmha/fmha_fwd.hpp | 11 +- example/ck_tile/01_fmha/generate.py | 36 ++- example/ck_tile/01_fmha/mask.hpp | 14 +- example/ck_tile/01_fmha/script/smoke_test.sh | 7 +- include/ck_tile/core/config.hpp | 5 + include/ck_tile/core/numeric/math.hpp | 11 + include/ck_tile/ops/fmha.hpp | 2 + .../fmha/block/block_attention_bias_enum.hpp | 37 +++ .../fmha/block/block_position_encoding.hpp | 189 +++++++++++++++ .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 85 +++++-- .../pipeline/block_fmha_pipeline_enum.hpp | 19 ++ .../pipeline/block_fmha_pipeline_problem.hpp | 2 +- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 51 ++++- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 58 +++-- .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 22 +- .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 51 ++++- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 5 +- test/CMakeLists.txt | 1 + test/position_embedding/CMakeLists.txt | 1 + .../position_embedding/position_embedding.cpp | 215 ++++++++++++++++++ 24 files changed, 948 insertions(+), 115 deletions(-) create mode 100644 example/ck_tile/01_fmha/bias.hpp create mode 100644 include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp create mode 100644 include/ck_tile/ops/fmha/block/block_position_encoding.hpp create mode 100644 test/position_embedding/CMakeLists.txt create mode 100644 test/position_embedding/position_embedding.cpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index e31c96caa..85d25c63d 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -17,7 +17,7 @@ add_custom_command( set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -message("adding tile_example ${EXAMPLE_NAME}") +message("adding example ${EXAMPLE_FMHA_FWD}") add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) 
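# (illustrative note) with EXCLUDE_FROM_ALL the example is skipped by `make all`; build it explicitly, e.g. `make tile_example_fmha_fwd` when using a Makefile generator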
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 5a428e4d4..fd5690a79 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -30,27 +30,29 @@ args: -mode kernel mode. 0:batch, 1:group (default:0) -b batch size (default:2) -h num of head, for q (default:8) - -h_k num of head, for k/v, 0 means equal to h (default:0) + -h_k num of head, for k/v, -1 means equal to h (default:-1) if not equal to h, then this is GQA/MQA case -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary - -s_k seqlen_k, 0 means equal to s (default:0) + -s_k seqlen_k, -1 means equal to s (default:-1) -d head dim for q, k (default:128) - -d_v head dim for v, 0 means equal to d (default:0) + -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) note when squant=1, this value will be modified by range_q/k - -range_q per-tensor quantization range of q. used if squant=1. (default:2) - -range_k per-tensor quantization range of k. used if squant=1. (default:2) - -range_v per-tensor quantization range of v. used if squant=1. (default:2) + -range_q per-tensor quantization range of q. used if squant=1. (default:16) + -range_k per-tensor quantization range of k. used if squant=1. (default:16) + -range_v per-tensor quantization range of v. used if squant=1. (default:16) -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) - -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:2) + -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) -squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0) 1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o -iperm permute input (default:1) if true, will be b*h*s*d, else b*s*h*d -operm permute output (default:1) - -bias add bias or not (default:0) + -bias n or 0, no bias (default:n) + e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s + a(libi) or 2, alibi with 1*h. a:1, b*h -prec data type. fp16/bf16/fp8/bf8 (default:fp16) -mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0) 't', top-left causal mask, 'b', bottom-r causal mask @@ -59,11 +61,11 @@ args: 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now) - -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) -lse 0 not store lse, 1 store lse (default:0) -kname if set to 1 will print kernel name (default:0) -init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1) + -seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939) ``` Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. 
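For the alibi path added in this patch, a minimal invocation (assuming the default build location and the flag semantics documented above) is `./bin/tile_example_fmha_fwd -b=2 -h=8 -s=1024 -d=128 -bias=a -mask=t`, i.e. a batch-mode fp16 case with one alibi slope per head and a top-left causal mask.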
@@ -85,6 +87,9 @@ If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support prov ### attention bias Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number. +### alibi +alibi is supported + ### lse For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1` diff --git a/example/ck_tile/01_fmha/bias.hpp b/example/ck_tile/01_fmha/bias.hpp new file mode 100644 index 000000000..f9dc656f6 --- /dev/null +++ b/example/ck_tile/01_fmha/bias.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +// keep sync with BlockAttentionBiasEnum +enum class bias_enum +{ + no_bias = 0, + elementwise_bias = 1, + alibi = 2, +}; + +struct bias_info +{ + bias_enum type; + /* + * simple dispatch logic + * + * if type == elementwise_bias: + * if rank_info == 0: + * bias is 1*1*s*s + * elif rank_info == 1: + * bias is 1*h*s*s + * elif rank_info == 2: + * bias is b*h*s*s + * + * elif type == alibi: + * if rank_info == 0: + * alibi in 1*h + * elif rank_info == 1: + * alibi in b*h + */ + int rank_info; + + void serialize(std::ostream& os) const + { + if(type == bias_enum::no_bias) + os << "n"; + else if(type == bias_enum::elementwise_bias) + { + os << "e"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + else if(type == bias_enum::alibi) + { + os << "alibi"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + } + + static bias_info decode(std::string str) + { + bias_info info{bias_enum::no_bias, 0}; + if(str == "0" || str == "n") + { + info.type = bias_enum::no_bias; + } + else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || + str.compare(0, 11, "elementwise") == 0) + { + info.type = bias_enum::elementwise_bias; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 || + str.compare(0, 5, "alibi") == 0) + { + info.type = bias_enum::alibi; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + return info; + } + + friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + { + bi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 8ca4ff933..686633bb2 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -41,16 +41,16 @@ auto create_args(int argc, char* argv[]) .insert("b", "2", "batch size") .insert("h", "8", "num of head, for q") .insert("h_k", - "0", - "num of head, for k/v, 0 means equal to h\n" + "-1", + "num of head, for k/v, -1 means equal to h\n" "if not equal to h, then this is GQA/MQA case") .insert("s", "3328", "seqlen_q. 
if group-mode, means the average value of seqlen_q\n" "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") - .insert("s_k", "0", "seqlen_k, 0 means equal to s") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") .insert("d", "128", "head dim for q, k") - .insert("d_v", "0", "head dim for v, 0 means equal to d") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" @@ -71,7 +71,11 @@ auto create_args(int argc, char* argv[]) "permute input\n" "if true, will be b*h*s*d, else b*s*h*d") .insert("operm", "1", "permute output") - .insert("bias", "0", "add bias or not") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("mask", "0", @@ -153,7 +157,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t batch = arg_parser.get_int("b"); ck_tile::index_t nhead = arg_parser.get_int("h"); ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - if(nhead_k == 0) + if(nhead_k < 0) nhead_k = nhead; if(nhead % nhead_k != 0) @@ -164,11 +168,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t seqlen_q = arg_parser.get_int("s"); ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - if(seqlen_k == 0) + if(seqlen_k < 0) seqlen_k = seqlen_q; ck_tile::index_t hdim_q = arg_parser.get_int("d"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - if(hdim_v == 0) + if(hdim_v < 0) hdim_v = hdim_q; bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim @@ -208,9 +212,9 @@ bool run(const ck_tile::ArgParser& arg_parser) } std::string vlayout = arg_parser.get_str("vlayout"); - bool use_bias = arg_parser.get_bool("bias"); bool lse = arg_parser.get_bool("lse"); + bias_info bias = bias_info::decode(arg_parser.get_str("bias")); mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); int init_method = arg_parser.get_int("init"); @@ -295,12 +299,18 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor v_host( is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); - // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host - // will not be used for verification at all (but will be copied to device anyway). + ck_tile::HostTensor bias_host( - use_bias + bias.type == bias_enum::elementwise_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + + ck_tile::HostTensor alibi_slope_host( + bias.type == bias_enum::alibi + ? (bias.rank_info == 0 ? std::array{1, nhead} + : std::array{batch, nhead}) + : std::array{1, 1}); + // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] ck_tile::HostTensor lse_host( lse ? 
std::array{shape_batch, nhead, shape_seqlen_q} @@ -341,6 +351,24 @@ bool run(const ck_tile::ArgParser& arg_parser) // Assume bias is in [-1.f, 1.f] in original fp32 ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); } + if(bias.type == bias_enum::alibi) + { + auto slopes = ck_tile::get_alibi_slopes(nhead); + assert(slopes.size() == nhead); + if(bias.rank_info == 0) + { + // alibi in 1*h + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); + } + else + { + // alibi in b*h + for(auto i_b = 0; i_b < batch; i_b++) + { + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); + } + } + } ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); @@ -350,6 +378,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); @@ -357,6 +386,7 @@ bool run(const ck_tile::ArgParser& arg_parser) bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqstart_k_host.data()); + alibi_slope_buf.ToDevice(alibi_slope_host.data()); // clang-format off auto layout_str = [&](bool permute){ @@ -372,9 +402,9 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k - << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s - << ", bias:" << use_bias << ", lse:" << lse << ", squant:" << squant - << ", mask:" << mask << ", v:" << vlayout << std::flush; + << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias + << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout + << std::flush; auto fmha_traits = fmha_fwd_traits{hdim_q, hdim_v, @@ -382,7 +412,7 @@ bool run(const ck_tile::ArgParser& arg_parser) mode == mode_enum::group, is_v_rowmajor, mask.type, - use_bias, + bias.type, lse, squant}; @@ -441,7 +471,8 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_fwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), - bias_buf.GetDeviceBuffer(), + bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), @@ -461,7 +492,8 @@ bool run(const ck_tile::ArgParser& arg_parser) stride_q, stride_k, stride_v, - stride_bias, + bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 
0 : nhead) + : stride_bias, stride_o, nhead_stride_q, nhead_stride_k, @@ -564,8 +596,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::identity{}, ck_tile::scales(scale_s)); - if(use_bias) + if(bias.type == bias_enum::elementwise_bias) { + // elementwise bias ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); // clang-format off if(i_perm) @@ -582,6 +615,51 @@ bool run(const ck_tile::ArgParser& arg_parser) SMPLComputeDataType>( s_host_ref, bias_host_ref, s_host_ref); } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) + { + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + SaccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } + } + } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); + } if(mask.type == mask_enum::no_mask) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 9a82ab6b7..fb3907fec 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" #include "mask.hpp" +#include "bias.hpp" #include template @@ -86,7 +87,7 @@ struct fmha_fwd_args const void* q_ptr; const void* k_ptr; const void* v_ptr; - const void* bias_ptr; + const void* bias_ptr; // bias or alibi_slope pointer void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -106,7 +107,7 @@ struct fmha_fwd_args ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; - ck_tile::index_t stride_bias; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; @@ -219,7 +220,7 @@ template ; - static constexpr bool kHasBias = kHasBias_; + static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kPadS = kPadS_; @@ -261,7 +262,7 @@ struct fmha_fwd_traits bool is_group_mode; bool is_v_rowmajor; mask_enum mask_type; - bool has_bias; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. 
sync with BlockAttentionBiasEnum bool has_lse; bool do_fp8_static_quant; // TODO: padding check is inside this api diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 56d699e5f..51fecd07b 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -40,6 +40,19 @@ MASK_MAP = { "generic" : "FmhaMasks::GenericMask" } +BIAS_MAP = { + "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" +} + +# TODO: this is ugly +BIAS_CHECK_MAP = { + "no" : "bias_enum::no_bias", + "bias" : "bias_enum::elementwise_bias", + "alibi" : "bias_enum::alibi" +} + MODE_MAP = { "batch" : "false", "group" : "true" @@ -173,7 +186,7 @@ MASK_SIMPLIFIED_CHECK_MAP = { "s_mask" : "t.mask_type != mask_enum::no_mask", } -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); @@ -213,7 +226,7 @@ class FmhaFwdApiTrait: bk0blen : int vlayout : str mask : str - bias : str # true/false + bias : str # lse : str # squant : str # spad : str @@ -241,8 +254,8 @@ class FmhaFwdApiTrait: def skcheck(self) -> str: if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k % {self.bn0} == 0' + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' elif self.pipeline_tag in ['qr', 'qr_fp8']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) else : return f'a.seqlen_k % {self.bn0} == 0' @@ -297,7 +310,7 @@ class FmhaFwdPipeline: pn = pad_name() n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' - if self.F_bias == 't' : n += '_bias' + if self.F_bias != 'no' : n += f'_{self.F_bias}' if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' else: @@ -332,7 +345,8 @@ class FmhaFwdApiPool: if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, @@ -400,7 +414,7 @@ class FmhaFwdKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_bias = BOOL_MAP[self.F_pipeline.F_bias], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_occupancy = self.F_tile.F_occupancy, @@ -454,7 +468,9 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ } elif dtype == 'fp8' or dtype == 'bf8': return { - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1) + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) } else: return None @@ -472,7 +488,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]): + for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): if hdim == 256: # if True: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) @@ -490,7 +506,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse kernels - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]): + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask)) else: assert False diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index 56fc8b8b1..c77b700b1 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. 
All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -149,11 +149,9 @@ struct mask_info return tmp; } - friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + { + mi.serialize(os); + return os; + } }; - -inline std::ostream& operator<<(std::ostream& os, const mask_info& mi) -{ - mi.serialize(os); - return os; -} diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test.sh index 4dd5c2ae1..2c4bb562a 100755 --- a/example/ck_tile/01_fmha/script/smoke_test.sh +++ b/example/ck_tile/01_fmha/script/smoke_test.sh @@ -17,7 +17,7 @@ for perm in 0 1 ; do for vlayout in "r" "c" ; do for hdim in 32 64 128 256 ; do for lse in 0 1 ; do -for bias in 0 1 ; do +for bias in "n" "e" "a"; do # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS @@ -27,6 +27,7 @@ $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$b $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS done done @@ -37,9 +38,11 @@ done done for perm in 0 1 ; do -for bias in 0 1 ; do +for bias in "n" "e" "a" ; do for b in 1 2 ; do +for hdim in 64 128 256 ; do $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS done done done +done diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index d915df6e4..82b6953b5 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -154,3 +154,8 @@ #ifndef CK_TILE_USE_SUBDWORD_TILE_CAST #define CK_TILE_USE_SUBDWORD_TILE_CAST 0 #endif + +// TODO: better solve this inside compiler +#ifndef CK_TILE_FMHA_FWD_FAST_EXP2 +#define CK_TILE_FMHA_FWD_FAST_EXP2 0 +#endif diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 72ec607b4..d4984363d 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -536,4 +536,15 @@ float log(float x) { return __logf(x); }; CK_TILE_HOST float log(float x) { return std::logf(x); }; +CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) +{ + // TODO: this is hacky, we use u16 + return __builtin_amdgcn_sad_u16(x, y, acc); +} + +CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) +{ + return (x > y ? 
(x - y) : (y - x)) + acc; +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index c567e63dd..1122bf87b 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -3,7 +3,9 @@ #pragma once +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/fmha/block/block_position_encoding.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp new file mode 100644 index 000000000..e5be21e04 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockAttentionBiasEnum +{ + NO_BIAS = 0, + ELEMENTWISE_BIAS = 1, // attention bias, each elements add to the result of Q*K(after scale) + ALIBI = 2, // bias computed with position encoding, applied after scale +}; + +template +struct BlockAttentionBiasEnumToStr; + +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = ""; +}; +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = "bias"; +}; +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = "alibi"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp new file mode 100644 index 000000000..9c6c35390 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp @@ -0,0 +1,189 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include +#include + +namespace ck_tile { + +enum struct PositionEncodingEnum +{ + NO = 0, + ALIBI = 1, +}; + +/* +VERTICAL: + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + +TOP_LEFT: + [0] 1 2 3 4 5 + 1 [0] 1 2 3 4 + 2 1 [0] 1 2 3 + 3 2 1 [0] 1 2 + +FROM_BOTTOM_RIGHT: + 2 1 [0] 1 2 3 + 3 2 1 [0] 1 2 + 4 3 2 1 [0] 1 + 5 4 3 2 1 [0] +*/ + +enum struct AlibiMode +{ + VERTICAL = 0, + FROM_TOP_LEFT = 1, // keep sync with mask enum + FROM_BOTTOM_RIGHT = 2, +}; + +template +struct Alibi +{ + // RowMajor here means if pixel within the same thread are along the row, or col + // this may impact the performance of update(), while the result are the same. + // e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false + CK_TILE_HOST_DEVICE Alibi(DataType slope_, + index_t y_total_, + index_t x_total_, + AlibiMode mode_ = AlibiMode::VERTICAL) + { + slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope; + + shift_left_up = [&]() { + if(RowMajor) + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; + } + else + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0; + } + }(); + shift_right_down = [&]() { + if(RowMajor) + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? 
max(x_total_ - y_total_, 0) : 0; + } + else + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; + } + }(); + mode = mode_; + } + + CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx) + { + if constexpr(RowMajor) + { + // at least 3 instructions per row + index_t current_zero_point = + mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down; + + // for every threads, most of the pixels are along the row, below operation should be + // the main hot spot. + auto position = type_convert(sad(bit_cast(current_zero_point), + bit_cast(col_idx + shift_left_up), + 0)); + pixel += slope * position; + } + else + { + // at least 3 instructions per col; + index_t current_zero_point = mode == AlibiMode::VERTICAL + ? row_idx + col_idx + shift_right_down + : col_idx + shift_right_down; + + // for every threads, most of the pixels are along the col, below operation should be + // the main hot spot. + auto position = type_convert(sad(bit_cast(current_zero_point), + bit_cast(row_idx + shift_left_up), + 0)); + pixel += slope * position; + } + } + + DataType slope; // float? + index_t shift_left_up; // always possitive + index_t shift_right_down; // always possitive + AlibiMode mode; +}; + +template +struct EmptyPositionEncoding +{ + CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/) + { + } +}; + +// +// can convert from the FA style left/right to our generic coordinate +// if left_size < 0 && right_size = 0, it is normal causal mask +// local is left_size >=0 or right_size >=0 +template +CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, + index_t window_left_size, + index_t window_right_size, + index_t y_total, + index_t x_total, + GenericAttentionMaskEnum mask_enum) +{ + // assume mask_enum will never be NO_MASK, since if we do not have mask, it's + // totally OK to use constexpr + bool is_causal = window_left_size < 0 && window_right_size == 0; + AlibiMode alibi_mode = + is_causal ? AlibiMode::VERTICAL + : static_cast(mask_enum) /*either top-left or bottom-right*/; + return Alibi{slope, y_total, x_total, alibi_mode}; +} + +// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742 +// Do we need a device version? 
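+// Illustrative values only (a sketch derived from the formula below, not a tested reference):
+// for a power-of-two head count the slopes form a geometric sequence, e.g.
+// get_alibi_slopes(8) ~= {1/2, 1/4, 1/8, 1/16, 1/32, 1/64, 1/128, 1/256}.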
+template +CK_TILE_HOST std::vector get_alibi_slopes(ck_tile::index_t nheads) +{ + auto get_slopes_power_of_2 = [](ck_tile::index_t n) { + float start = std::powf( + static_cast(2), + -std::powf(static_cast(2), -static_cast((integer_log2_floor(n) - 3)))); + + std::vector rtn; + for(auto i = 0; i < n; i++) + { + rtn.push_back(static_cast(start * std::powf(start, i))); + } + return rtn; + }; + if(is_power_of_two_integer(nheads)) + { + // power of 2 calculation + return get_slopes_power_of_2(nheads); + } + else + { + ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads); + auto v0 = get_slopes_power_of_2(closest_power_of_2); + auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2); + auto v1_sliced = [&](auto vec, ck_tile::index_t rem) { + std::vector sliced; + for(ck_tile::index_t i = 0; i < static_cast(vec.size()); i++) + { + if(i % 2 == 0) + sliced.push_back(vec[i]); + } + std::vector sliced_2(sliced.begin(), sliced.begin() + rem); + return sliced_2; + }(v1, nheads - closest_power_of_2); + v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end()); + return v0; + } +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0732fd2ce..10ce7395a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include #include @@ -33,6 +34,7 @@ struct FmhaFwdKernel using BiasDataType = ck_tile::remove_cvref_t; using LSEDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; using VLayout = ck_tile::remove_cvref_t; @@ -41,7 +43,7 @@ struct FmhaFwdKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr bool kHasBias = FmhaPipeline::kHasBias; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; using FmhaMask = ck_tile::remove_cvref_t; @@ -81,7 +83,8 @@ struct FmhaFwdKernel "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + - (kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? 
"_squant" : "" ); #undef _SS_ #undef _TS_ // clang-format on @@ -136,6 +139,13 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_bias = 0; }; + struct FmhaFwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; @@ -162,7 +172,11 @@ struct FmhaFwdKernel struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, - std::conditional_t>, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -175,7 +189,11 @@ struct FmhaFwdKernel struct FmhaFwdGroupModeKargs : FmhaFwdCommonKargs, - std::conditional_t>, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -255,13 +273,18 @@ struct FmhaFwdKernel batch_stride_v, batch_stride_o}; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; kargs.batch_stride_bias = batch_stride_bias; } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } if constexpr(kHasMask) { kargs.window_size_left = window_size_left; @@ -345,12 +368,17 @@ struct FmhaFwdKernel reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } if constexpr(kHasMask) { kargs.window_size_left = window_size_left; @@ -421,14 +449,10 @@ struct FmhaFwdKernel { batch_offset_v = key_start; } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias + key_start; } - else - { - batch_offset_bias = key_start; - } if constexpr(kStoreLSE) { batch_offset_lse = query_start; @@ -461,7 +485,7 @@ struct FmhaFwdKernel batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } @@ -585,7 +609,7 @@ struct FmhaFwdKernel const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { constexpr auto bias_dram_window_lengths = make_tuple(number{}, number{}); - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { const BiasDataType* bias_ptr = reinterpret_cast(kargs.bias_ptr) + @@ -654,6 +678,39 @@ struct FmhaFwdKernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? 
+ SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + auto o_acc_tile = [&]() { if constexpr(kDoFp8StaticQuant) { @@ -672,6 +729,7 @@ struct FmhaFwdKernel scales{kargs.scale_p}, // p_compute_element_func composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func mask, + position_encoding, kargs.scale_s, smem_ptr); } @@ -683,6 +741,7 @@ struct FmhaFwdKernel bias_dram_window, lse_dram_window, mask, + position_encoding, kargs.scale_s, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index 550017408..cf70dff63 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -13,4 +13,23 @@ enum class BlockFmhaPipelineEnum QSKSVS, }; +template +struct BlockFmhaPipelineEnumToStr; + +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr"; +}; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_async"; +}; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qs"; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 9d27b2df6..159fb4074 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -45,7 +45,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kHasBias = Traits::kHasBias; + static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 9e239bb91..60650761d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -46,7 +47,7 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; // last dimension vector length used to create tensor 
view(and decide buffer_load vector length) @@ -82,7 +83,7 @@ struct BlockFmhaPipelineQRKSVS } else if constexpr(kK0BlockLength <= 128) { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -113,7 +114,8 @@ struct BlockFmhaPipelineQRKSVS typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, - typename OAccElementFunction> + typename OAccElementFunction, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -129,6 +131,7 @@ struct BlockFmhaPipelineQRKSVS const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -270,13 +273,13 @@ struct BlockFmhaPipelineQRKSVS k_block_tile = load_tile(k_dram_window); } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -322,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS } // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); @@ -338,6 +341,25 @@ struct BlockFmhaPipelineQRKSVS s_acc, bias_tile); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); @@ -382,7 +404,8 @@ struct BlockFmhaPipelineQRKSVS static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? 
type_convert(0.f) @@ -403,7 +426,8 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -427,7 +451,8 @@ struct BlockFmhaPipelineQRKSVS constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } @@ -519,7 +544,8 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); } @@ -563,7 +589,8 @@ struct BlockFmhaPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -571,6 +598,7 @@ struct BlockFmhaPipelineQRKSVS const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -588,6 +616,7 @@ struct BlockFmhaPipelineQRKSVS identity{}, identity{}, mask, + position_encoding, scale_s, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 0573b50d0..8a19deb02 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -51,7 +52,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; // last dimension vector length used to create tensor view(and decide buffer_load vector length) @@ -79,21 +80,22 @@ struct BlockFmhaPipelineQRKSVSAsync { if constexpr(kK0BlockLength <= 32) { - if constexpr(kPadSeqLenK && kHasBias && FmhaMask::IsMasking) + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && + FmhaMask::IsMasking) return 1; else return 2; } else if constexpr(kK0BlockLength <= 64) { - if constexpr(kPadSeqLenK && kHasBias) + if 
constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 2; else return 3; } else if constexpr(kK0BlockLength <= 128) { - if constexpr(kPadSeqLenK && kHasBias) + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -124,7 +126,8 @@ struct BlockFmhaPipelineQRKSVSAsync typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, - typename OAccElementFunction> + typename OAccElementFunction, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -140,6 +143,7 @@ struct BlockFmhaPipelineQRKSVSAsync const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -247,8 +251,8 @@ struct BlockFmhaPipelineQRKSVSAsync const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking) + // check early exit + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) { if(num_total_loop <= 0) { @@ -367,7 +371,7 @@ struct BlockFmhaPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(1); // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); @@ -383,6 +387,25 @@ struct BlockFmhaPipelineQRKSVSAsync s_acc, bias_tile); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); @@ -463,8 +486,9 @@ struct BlockFmhaPipelineQRKSVSAsync static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? 
type_convert(0.f) @@ -485,7 +509,8 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -509,7 +534,8 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } @@ -617,7 +643,8 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); } @@ -661,7 +688,8 @@ struct BlockFmhaPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -669,6 +697,7 @@ struct BlockFmhaPipelineQRKSVSAsync const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -686,6 +715,7 @@ struct BlockFmhaPipelineQRKSVSAsync identity{}, identity{}, mask, + position_encoding, scale_s, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 0e59ee6fe..80f40f815 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -46,7 +47,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; // last dimension vector length used to create tensor view(and decide buffer_load vector length) @@ -82,7 +83,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 } else if constexpr(kK0BlockLength <= 128) { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -105,7 +106,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename 
LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -113,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported FmhaMask mask, + PositionEncoding /*position_encoding*/, float scale_s, float descale_qk, float descale_sv, @@ -249,13 +252,13 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 k_block_tile = load_tile(k_dram_window); } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -300,7 +303,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 } // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -356,7 +359,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? type_convert(0.f) @@ -377,7 +381,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -401,7 +405,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 677c05769..e12e76706 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" namespace ck_tile { @@ -45,7 +46,7 @@ struct BlockFmhaPipelineQSKSVS static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr index_t kBlockPerCu = []() { @@ -63,7 +64,7 @@ struct BlockFmhaPipelineQSKSVS } else if constexpr(kK0BlockLength <= 128) 
{ - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -99,7 +100,8 @@ struct BlockFmhaPipelineQSKSVS typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, - typename OAccElementFunction> + typename OAccElementFunction, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -115,6 +117,7 @@ struct BlockFmhaPipelineQSKSVS const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -265,13 +268,13 @@ struct BlockFmhaPipelineQSKSVS k_block_tile = load_tile(k_dram_window); } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -313,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS } // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); @@ -329,6 +332,25 @@ struct BlockFmhaPipelineQSKSVS s_acc, bias_tile); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); @@ -373,7 +395,8 @@ struct BlockFmhaPipelineQSKSVS static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? 
type_convert(0.f) @@ -394,7 +417,8 @@ struct BlockFmhaPipelineQSKSVS sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -418,7 +442,8 @@ struct BlockFmhaPipelineQSKSVS constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } @@ -510,7 +535,8 @@ struct BlockFmhaPipelineQSKSVS sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); } @@ -554,7 +580,8 @@ struct BlockFmhaPipelineQSKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -562,6 +589,7 @@ struct BlockFmhaPipelineQSKSVS const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -579,6 +607,7 @@ struct BlockFmhaPipelineQSKSVS identity{}, identity{}, mask, + position_encoding, scale_s, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 137f4ddd8..6cb6449f1 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" namespace ck_tile { @@ -11,7 +12,7 @@ template @@ -21,7 +22,7 @@ struct TileFmhaTraits static constexpr bool kPadSeqLenK = kPadSeqLenK_; static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr bool kHasBias = kHasBias_; + static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLSE = kStoreLSE_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 33aa10df7..25c63ac7f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -181,3 +181,4 @@ add_subdirectory(wrapper) if(GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() +add_subdirectory(position_embedding) diff --git a/test/position_embedding/CMakeLists.txt b/test/position_embedding/CMakeLists.txt new file mode 100644 index 000000000..e7a939beb --- /dev/null +++ b/test/position_embedding/CMakeLists.txt @@ -0,0 +1 @@ +add_test_executable(test_position_embedding position_embedding.cpp) diff --git a/test/position_embedding/position_embedding.cpp b/test/position_embedding/position_embedding.cpp new 
file mode 100644 index 000000000..e295ec454 --- /dev/null +++ b/test/position_embedding/position_embedding.cpp @@ -0,0 +1,215 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +#ifndef TEST_ALIBI_VERBOSE +#define TEST_ALIBI_VERBOSE 0 +#endif + +template +struct attention_score +{ + ck_tile::index_t rows, cols; + std::vector pixels; + + attention_score(ck_tile::index_t rows_, + ck_tile::index_t cols_, + DataType init_v_ = static_cast(0)) + : rows(rows_), cols(cols_), pixels(rows_ * cols_, init_v_) + { + } + + auto& operator()(ck_tile::index_t i_row, ck_tile::index_t i_col) + { + return pixels[i_row * cols + i_col]; + } + + void print() + { + for(auto i_row = 0; i_row < rows; i_row++) + { + for(auto i_col = 0; i_col < cols; i_col++) + { + std::cout << pixels[i_row * cols + i_col] << " "; + } + std::cout << std::endl; + } + } +}; + +template +void alibi_traverse_with_slope(attention_score& score, + DataType slope, + ck_tile::AlibiMode mode = ck_tile::AlibiMode::VERTICAL) +{ + using Alibi = ck_tile::Alibi; + auto alibi = Alibi{slope, score.rows, score.cols, mode}; + + for(ck_tile::index_t i_row = 0; i_row < score.rows; i_row++) + { + for(ck_tile::index_t i_col = 0; i_col < score.cols; i_col++) + { + alibi.update(score(i_row, i_col), i_row, i_col); + } + } +} + +std::string alibi_mode_to_str(ck_tile::AlibiMode mode) +{ + if(mode == ck_tile::AlibiMode::VERTICAL) + return std::string("alibi_verti"); + else if(mode == ck_tile::AlibiMode::FROM_TOP_LEFT) + return std::string("alibi_top-l"); + else if(mode == ck_tile::AlibiMode::FROM_BOTTOM_RIGHT) + return std::string("alibi_bot-r"); + return ""; +} + +template +bool test_alibi_traverse_with_slope(ck_tile::index_t rows, + ck_tile::index_t cols, + DataType slope, + ck_tile::AlibiMode mode, + const std::vector& expected) +{ + attention_score score{rows, cols}; + alibi_traverse_with_slope(score, slope, mode); + + bool is_match = std::equal(score.pixels.begin(), score.pixels.end(), expected.begin()); +#if TEST_ALIBI_VERBOSE + std::cout << "---------" << alibi_mode_to_str(mode) << ", " << rows << "x" << cols << "(" + << (RowMajor ? "row_major" : "col_major") << ")" + << (is_match ? ", valid:y" : ", valid:n") << std::endl; + score.print(); +#endif + return is_match; +} + +template +bool test_alibi_slope_generation(ck_tile::index_t nheads, const std::vector& expected) +{ + auto slopes = ck_tile::get_alibi_slopes(nheads); + + bool is_match = std::equal(slopes.begin(), + slopes.end(), + expected.begin(), + expected.end(), + [](const DataType& lhs, const DataType& rhs) { + constexpr float rtol = 1e-6; + auto error = std::abs(lhs - rhs); + return error < rtol * std::abs(rhs); + }); +#if TEST_ALIBI_VERBOSE + std::cout << "-------------------- slopes " << nheads << ", " << (is_match ?
"y" : "n") + << std::endl; + for(ck_tile::index_t i = 0; i < nheads; i++) + { + std::cout << slopes[i] << " "; + } + std::cout << std::endl; +#endif + return is_match; +} + +int main() +{ + using dtype = int32_t; + dtype slope = static_cast(1); + + bool rtn = true; + + // clang-format off + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, + 1, 0, 1, 2, 3, 4, + 2, 1, 0, 1, 2, 3, + 3, 2, 1, 0, 1, 2}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, + 1, 0, 1, 2, + 2, 1, 0, 1, + 3, 2, 1, 0, + 4, 3, 2, 1, + 5, 4, 3, 2}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, + 1, 0, 1, + 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, + 3, 2, 1, 0, 1, 2, + 4, 3, 2, 1, 0, 1, + 5, 4, 3, 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, + 1, 2, 3, 4, + 0, 1, 2, 3, + 1, 0, 1, 2, + 2, 1, 0, 1, + 3, 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, + 1, 0, 1, + 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, + 1, 0, 1, 2, 3, 4, + 2, 1, 0, 1, 2, 3, + 3, 2, 1, 0, 1, 2}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, + 1, 0, 1, 2, + 2, 1, 0, 1, + 3, 2, 1, 0, + 4, 3, 2, 1, + 5, 4, 3, 2}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, + 1, 0, 1, + 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, + 3, 2, 1, 0, 1, 2, + 4, 3, 2, 1, 0, 1, + 5, 4, 3, 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, + 1, 2, 3, 4, + 0, 1, 2, 3, + 1, 0, 1, 2, + 2, 1, 0, 1, + 3, 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, + 1, 0, 1, + 2, 1, 0}); + + rtn &= test_alibi_slope_generation(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625}); + rtn &= test_alibi_slope_generation(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692, + 0.12500000000000006, 0.08838834764831849, 0.06250000000000004, 0.044194173824159244, + 0.03125000000000002, 0.022097086912079626, 0.01562500000000001, 0.011048543456039816, + 0.007812500000000007, 0.005524271728019908, 0.003906250000000004}); + rtn &= test_alibi_slope_generation(1, {0.00390625}); + rtn &= test_alibi_slope_generation(5, {0.25, 0.0625, 0.015625, 0.00390625, 0.5}); + rtn &= test_alibi_slope_generation(6, {0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125}); + rtn &= test_alibi_slope_generation(7, {0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125, 0.03125}); + rtn &= test_alibi_slope_generation(9, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625, 0.7071067811865476}); + // clang-format on + return rtn ? 
0 : -1; +} -- GitLab From bf42097646884e9917b69cae1c150fc12f697a4c Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 7 May 2024 16:26:43 -0700 Subject: [PATCH 11/96] Enable logging in CK with environment variable. (#1278) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * enable logging using environment variable * update ck.hpp header * fix typo * fix clang format * Update include/ck/utility/env.hpp Co-authored-by: Bartłomiej Kocot --------- Co-authored-by: Bartłomiej Kocot --- include/ck/ck.hpp | 10 +- include/ck/host_utility/flush_cache.hpp | 45 +++-- include/ck/host_utility/kernel_launch.hpp | 64 +++--- ...ultiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 49 ++--- ...evice_batched_gemm_reduce_xdl_cshuffle.hpp | 31 +-- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 7 +- ...ice_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp | 37 ++-- ...fle_bias_activation_add_nhwc_kyxc_nhwk.hpp | 5 +- ...shuffle_bias_activation_nhwc_kyxc_nhwk.hpp | 5 +- ...onv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 5 +- .../device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp | 4 +- ...evice_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp | 3 +- .../device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp | 3 +- ...device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp | 3 +- .../gpu/device/impl/device_gemm_dl.hpp | 3 +- .../impl/device_gemm_reduce_xdl_cshuffle.hpp | 3 +- .../device_gemm_xdl_layernorm_cshuffle.hpp | 3 +- .../impl/device_gemm_xdl_skip_b_lds.hpp | 3 +- .../device_grouped_gemm_multiple_d_dl.hpp | 41 ++-- ...ltiple_d_splitk_xdl_cshuffle_two_stage.hpp | 60 +++--- ...gemm_multiple_d_xdl_cshuffle_tile_loop.hpp | 12 +- .../device/impl/device_grouped_gemm_xdl.hpp | 45 ++--- ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 22 ++- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 136 ++++++------- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 126 ++++++------ .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 138 ++++++------- include/ck/utility/env.hpp | 185 ++++++++++++++++++ .../profile_grouped_gemm_fixed_nk_impl.hpp | 11 +- .../profiler/profile_grouped_gemm_impl.hpp | 11 +- .../profile_grouped_gemm_tile_loop_impl.hpp | 11 +- .../profile_grouped_gemm_two_stage_impl.hpp | 11 +- 31 files changed, 650 insertions(+), 442 deletions(-) create mode 100644 include/ck/utility/env.hpp diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 31dcb5f1b..c8025f53c 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -4,12 +4,19 @@ #pragma once #include "ck/config.h" +#include "ck/utility/env.hpp" #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" #endif +// environment variable to enable logging: +// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED +CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + +// to do: add various levels of logging with CK_LOG_LEVEL + #define CK_TIME_KERNEL 1 // constant address space for kernel parameter @@ -225,9 +232,6 @@ // workaround: compiler issue on gfx908 #define CK_WORKAROUND_SWDEV_388832 1 -// flag to enable (1) or disable (0) the debugging output in some kernels -#define DEBUG_LOG 0 - // denorm test fix, required to work around dissue #ifndef CK_WORKAROUND_DENORM_FIX #define CK_WORKAROUND_DENORM_FIX 0 diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index f490edb67..a93853c34 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -117,18 +117,19 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, 
#define MEDIAN 1 if(stream_config.time_kernel_) { -#if DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up %d times\n", stream_config.cold_niters_); -#endif + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } // warm up for(int i = 0; i < stream_config.cold_niters_; ++i) { @@ -141,9 +142,10 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { return 0.0; } -#if DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } #if MEDIAN std::set times; @@ -184,13 +186,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, total_time += cur_time; #endif -#if DEBUG_LOG - std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; - printf("args.p_a_grid: %p, args.p_b_grid:%p\n", - static_cast(args.p_a_grid), - static_cast(args.p_b_grid)); -#endif + printf("args.p_a_grid: %p, args.p_b_grid:%p\n", + static_cast(args.p_a_grid), + static_cast(args.p_b_grid)); + } } #if MEDIAN diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 1ed7686e7..df85f06c7 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -20,18 +20,19 @@ float launch_and_time_kernel(const StreamConfig& stream_config, #if CK_TIME_KERNEL if(stream_config.time_kernel_) { -#if DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up %d times\n", stream_config.cold_niters_); -#endif + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } // warm up for(int i = 0; i < stream_config.cold_niters_; ++i) { @@ -40,9 +41,10 @@ float launch_and_time_kernel(const StreamConfig& stream_config, } const int nrepeat = stream_config.nrepeat_; -#if DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } hipEvent_t start, stop; hip_check_error(hipEventCreate(&start)); @@ -93,18 +95,19 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, #if CK_TIME_KERNEL if(stream_config.time_kernel_) { -#if DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up %d times\n", stream_config.cold_niters_); -#endif + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } // warm up preprocess(); 
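All of the timing printouts below sit behind the same ck::EnvIsEnabled(ENV(CK_LOGGING)) guard introduced in ck.hpp above; the simplest way to exercise them, using the spellings documented there, is to set the variable before running any CK binary (the profiler invocation is only an illustration):

export CK_LOGGING=1        # or ON, or ENABLED
./bin/ckProfiler gemm ...  # any CK executable; remaining arguments elided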
for(int i = 0; i < stream_config.cold_niters_; ++i) @@ -114,9 +117,10 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, } const int nrepeat = stream_config.nrepeat_; -#if DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } hipEvent_t start, stop; hip_check_error(hipEventCreate(&start)); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 8f533ef62..4521b2161 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -587,30 +587,31 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle BatchStrideD1s, BatchStrideE1} { -#if DEBUG_LOG - std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", " - << a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl; - std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", " - << b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; - std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) << ", " - << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl; - std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", " - << b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; - std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{" - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", " - << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}" - << std::endl; - std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", " - << e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; -#endif + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", " + << a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl; + std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", " + << b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; + std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) + << ", " << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl; + std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", " + << b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; + std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{" + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", " + << 
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}" + << std::endl; + std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", " + << e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } static_for<0, NumD0Tensor, 1>{}([&](auto i) { using D0Layout = remove_cvref_t>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index d491ee2ea..37ebe2f85 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -658,27 +658,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { - std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl; + { + std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl; - std::cout << "arg.a_grid_desc_ak0_m_ak1_{" - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; - std::cout << "arg.b_grid_desc_bk0_n_bk1_{" - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) - << "}" << std::endl; + std::cout << "arg.reduce_grid_desc_m_{ " + << arg.reduce_grid_desc_m_.GetLength(I0) << "}" << std::endl; + } } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 00a89c47b..445467be5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -719,9 +719,10 @@ struct 
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { -#if DEBUG_LOG - arg.Print(); -#endif + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + arg.Print(); + } if(!ck::is_xdl_supported()) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index e22c5a2aa..6fd8c0323 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -516,26 +516,27 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { - std::cout << "arg.a_grid_desc_k0_m_k1_container_{" - << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " - << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" - << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_container_{" - << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " - << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" - << std::endl; - - std::cout << "arg.c_grid_desc_m_n_container_{ " - << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " - << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" - << std::endl; + { + std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n_container_{ " + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + } } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index c9e8940ed..f5c1460f5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -644,7 +644,7 @@ struct float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " @@ -664,9 +664,7 @@ struct << arg.input_left_pads_[1] << ", " << std::endl; std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " << arg.input_right_pads_[1] << ", " << std::endl; - } - { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << 
arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; @@ -684,7 +682,6 @@ struct std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 28fceb428..9015f640a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " @@ -634,9 +634,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X << arg.input_left_pads_[1] << ", " << std::endl; std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " << arg.input_right_pads_[1] << ", " << std::endl; - } - { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; @@ -651,7 +649,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index ca291d3b1..e815c0784 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " @@ -599,9 +599,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W << arg.input_left_pads_[1] << ", " << std::endl; std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " << arg.input_right_pads_[1] << ", " << std::endl; - } - { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; @@ -635,7 +633,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W .GetLength(I5) << "}" << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index ef94120f4..760e2840d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " @@ -444,7 +444,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } -#endif + if(!GridwiseGemm::CheckValidity( arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 55cf8df27..de4871939 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) @@ -415,7 +415,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index d95671be7..5d9f8a178 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1_container_{" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " @@ -1305,7 +1305,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5) << " ) " << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index ee3f0cea1..439872455 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " @@ -1239,7 +1239,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << std::endl; } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index bac124a2f..515892142 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { -#if DEBUG_LOG - std::cout << "The group count is not equal to sum of skipped groups " - "and kernel args size!" - << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; + } return false; } @@ -832,11 +835,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); if(not group_arg_valid) { -#if DEBUG_LOG - std::cout << "[" << __func__ << "] group id: " << i - << " has invalid GridwiseGemm settings!" << std::endl; - gemm_arg.Print(); -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" << std::endl; + gemm_arg.Print(); + } } supported = supported && group_arg_valid; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 0a0e8072b..7c252092a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -620,11 +620,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop GridwiseGemm::template CheckTensorTransfersValidity( M, N, K))) { -#if DEBUG_LOG - std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << "," << K - << "] are not supported by current template parameters!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; -#endif + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << "," + << K << "] are not supported by current template parameters!" 
+ << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__; + } supported = false; } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 7dfb677ec..90c0593b2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -514,28 +514,29 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { -#if DEBUG_LOG - std::cout << "The group count is not equal to sum of skipped groups " - "and kernel args size!" - << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; + } return false; } @@ -544,11 +545,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK, bhalf_t>::value) { -#if DEBUG_LOG - std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } if(karg.KBatch > 1) { return false; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index ab3449b1c..fdafa9ca5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -1113,12 +1113,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.M % MPerBlock == 0)) { -#if DEBUG_LOG - std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } @@ -1130,12 +1130,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.N % NPerBlock == 0)) { -#if DEBUG_LOG - std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } @@ -1149,12 +1149,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 auto K_t = karg.KBatch * KPerBlock; if(!(karg.K % K_t == 0)) { -#if DEBUG_LOG - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " - << karg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! 
K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1173,13 +1173,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % ABlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1187,13 +1187,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % ABlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1202,13 +1202,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % BBlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1216,13 +1216,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % BBlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1231,14 +1231,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { -#if DEBUG_LOG - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } @@ -1246,14 +1247,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { -#if DEBUG_LOG - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index b52f5c51b..f2eeaf7e3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -446,12 +446,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(!(karg.M % MPerBlock == 0)) { -#if DEBUG_LOG - std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } @@ -463,12 +463,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(!(karg.N % NPerBlock == 0)) { -#if DEBUG_LOG - std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } @@ -482,12 +482,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto K_t = karg.k_batch * K0PerBlock * K1; if(!(karg.K % K_t == 0)) { -#if DEBUG_LOG - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " - << karg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -496,13 +496,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.K % ABlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -510,13 +510,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.M % ABlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -525,13 +525,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.N % BBlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -539,13 +539,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.K % BBlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -554,14 +554,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) { -#if DEBUG_LOG - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of " - "CBlockTransferScalarPerVector_NWaveNPerXDL (" - << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -569,14 +569,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) { -#if DEBUG_LOG - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of " - "CBlockTransferScalarPerVector_NWaveNPerXDL (" - << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! 
" << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -584,12 +584,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 const auto num_k_loop = karg.K0Padded / K0PerBlock; if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { -#if DEBUG_LOG - std::cout << "The number of k loops (" << num_k_loop - << ") value is not supported by GridwiseGemm Pipeline." - << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "The number of k loops (" << num_k_loop + << ") value is not supported by GridwiseGemm Pipeline." + << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp new file mode 100644 index 000000000..0b6504e52 --- /dev/null +++ b/include/ck/utility/env.hpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace ck { +namespace internal { +template +struct ParseEnvVal +{ +}; + +template <> +struct ParseEnvVal +{ + static bool parse_env_var_value(const char* vp) + { + std::string value_env_str{vp}; + + for(auto& c : value_env_str) + { + if(std::isalpha(c) != 0) + { + c = std::tolower(static_cast(c)); + } + } + + if(value_env_str == "disable" || value_env_str == "disabled" || value_env_str == "0" || + value_env_str == "no" || value_env_str == "off" || value_env_str == "false") + { + return false; + } + else if(value_env_str == "enable" || value_env_str == "enabled" || value_env_str == "1" || + value_env_str == "yes" || value_env_str == "on" || value_env_str == "true") + { + return true; + } + else + { + throw std::runtime_error("Invalid value for env variable"); + } + + return false; // shouldn't reach here + } +}; + +// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default). +// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string). 
+template <> +struct ParseEnvVal +{ + static uint64_t parse_env_var_value(const char* vp) { return std::strtoull(vp, nullptr, 0); } +}; + +template <> +struct ParseEnvVal +{ + static std::string parse_env_var_value(const char* vp) { return std::string{vp}; } +}; + +template +struct EnvVar +{ + private: + T value{}; + bool is_unset = true; + + public: + const T& GetValue() const { return value; } + + bool IsUnset() const { return is_unset; } + + void Unset() { is_unset = true; } + + void UpdateValue(const T& val) + { + is_unset = false; + value = val; + } + + explicit EnvVar(const char* const name, const T& def_val) + { + // NOLINTNEXTLINE (concurrency-mt-unsafe) + const char* vp = std::getenv(name); + if(vp != nullptr) // a value was provided + { + is_unset = false; + value = ParseEnvVal::parse_env_var_value(vp); + } + else // no value provided, use default value + { + value = def_val; + } + } +}; +} // end namespace internal + +// static inside function hides the variable and provides +// thread-safety/locking +// Used in global namespace +#define CK_DECLARE_ENV_VAR(name, type, default_val) \ + namespace ck::env { \ + struct name \ + { \ + static_assert(std::is_same_v, \ + "CK_DECLARE_ENV* must be used in the global namespace"); \ + using value_type = type; \ + static ck::internal::EnvVar& Ref() \ + { \ + static ck::internal::EnvVar var{#name, default_val}; \ + return var; \ + } \ + }; \ + } + +#define CK_DECLARE_ENV_VAR_BOOL(name) CK_DECLARE_ENV_VAR(name, bool, false) + +#define CK_DECLARE_ENV_VAR_UINT64(name) CK_DECLARE_ENV_VAR(name, uint64_t, 0) + +#define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "") + +#define ENV(name) \ + ck::env::name {} + +template +inline const std::string& EnvGetString(EnvVar) +{ + static_assert(std::is_same_v); + return EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsEnabled(EnvVar) +{ + static_assert(std::is_same_v); + return !EnvVar::Ref().IsUnset() && EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsDisabled(EnvVar) +{ + static_assert(std::is_same_v); + return !EnvVar::Ref().IsUnset() && !EnvVar::Ref().GetValue(); +} + +template +inline uint64_t EnvValue(EnvVar) +{ + static_assert(std::is_same_v); + return EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsUnset(EnvVar) +{ + return EnvVar::Ref().IsUnset(); +} + +template +void EnvUnset(EnvVar) +{ + EnvVar::Ref().Unset(); +} + +/// updates the cached value of an environment variable +template +void UpdateEnvVar(EnvVar, const ValueType& val) +{ + static_assert(std::is_same_v); + EnvVar::Ref().UpdateValue(val); +} + +template +void UpdateEnvVar(EnvVar, const std::string_view& val) +{ + EnvVar::Ref().UpdateValue( + ck::internal::ParseEnvVal::parse_env_var_value(val.data())); +} + +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp index 67fba43d6..80c1c42b8 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -88,11 +88,12 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); -#if DEBUG_LOG - std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i - << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i - << "]:" << c_m_n_device_results[i].mDesc << std::endl; -#endif // DEBUG_LOG + 
if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" + << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + } std::size_t num_thread = 1; switch(init_method) { diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 7f48ee069..476ec37eb 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -87,11 +87,12 @@ bool profile_grouped_gemm_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); -#if DEBUG_LOG - std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i - << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i - << "]:" << c_m_n_device_results[i].mDesc << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" + << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + } std::size_t num_thread = 1; switch(init_method) { diff --git a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp index 3d7fa4707..33e758f40 100644 --- a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp @@ -82,11 +82,12 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification, Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); -#if DEBUG_LOG - std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i - << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i - << "]:" << c_m_n_device_results[i].mDesc << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" + << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + } switch(init_method) { case 0: break; diff --git a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp index 41dcabbfc..feb0be87e 100644 --- a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp @@ -88,11 +88,12 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); -#if DEBUG_LOG - std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i - << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i - << "]:" << c_m_n_device_results[i].mDesc << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" + << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + } std::size_t num_thread = 1; switch(init_method) { -- GitLab From 
0b6b5d1785c8da9b857456a727cb16d5905271ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 8 May 2024 09:53:24 +0200 Subject: [PATCH 12/96] Add two stage grouped conv bwd weight kernel (#1280) --- .../11_grouped_conv_bwd_weight/common.hpp | 6 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 898 ++++++++++++++++++ ...conv_bwd_weight_two_stage_xdl_instance.hpp | 52 + .../grouped_convolution_backward_weight.hpp | 4 + ...rouped_convolution_backward_weight_xdl.inc | 24 + .../grouped_conv2d_bwd_weight/CMakeLists.txt | 3 +- ...age_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 48 + .../grouped_conv3d_bwd_weight/CMakeLists.txt | 3 +- ..._xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 48 + .../profile_grouped_conv_bwd_weight_impl.hpp | 4 + 10 files changed, 1087 insertions(+), 3 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/client_example/11_grouped_conv_bwd_weight/common.hpp b/client_example/11_grouped_conv_bwd_weight/common.hpp index 1a36490ef..541a0a19a 100644 --- a/client_example/11_grouped_conv_bwd_weight/common.hpp +++ b/client_example/11_grouped_conv_bwd_weight/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -160,6 +160,10 @@ bool run_grouped_conv_bwd_weight( auto invoker_ptr = op_ptr->MakeInvokerPointer(); std::string op_name = op_ptr->GetTypeString(); + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp new file mode 100644 index 000000000..d30252e68 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -0,0 +1,898 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_xdlops_bwd_weight( + const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = batch_count; + ignore = block_2_ctile_map; + ignore = compute_ptr_offset_of_batch; + + compute_ptr_offset_of_batch.GetAPtrOffset(0); + compute_ptr_offset_of_batch.GetBPtrOffset(0); + compute_ptr_offset_of_batch.GetCPtrOffset(0); +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle + : public DeviceGroupedConvBwdWeight +{ + 
using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using EDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CDEElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvBwdWeightToGemm{}; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1}; + const std::array strides{1, 1, 1, 1}; + const std::array params{1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, + BDataType, + AccDataType, + AccDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + element_wise::PassThrough, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 
ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true, + 1, + PipelineVersion::v1, + ComputeTypeA, + ComputeTypeB>; + + static constexpr index_t ClusterLengthMPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwise = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + Sequence, + Sequence, + I1, + I1>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_e_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + ce_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + cde_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + constexpr index_t spatial_offset = 3; + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + 
begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + const auto descs = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_); + elementwise_block_2_ctile_map_ = Block2TileMapElementwise{ + ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + ce_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock( + ce_grid_desc_m_n_); + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N ce_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2CTileMap block_2_ctile_map_; + Block2TileMapElementwise elementwise_block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.ce_grid_desc_m_n_{" << 
arg.ce_grid_desc_m_n_.GetLength(I0) << ", " + << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.ce_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + auto launch_gemm_kernel = [&](auto has_main_k_block_loop) { + AccDataType* p_c_grid = type_convert(arg.p_workspace_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; + + constexpr bool has_main_loop = has_main_k_block_loop.value; + + auto preprocess = [&]() { + hip_check_error(hipMemsetAsync( + p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); + }; + + const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, + BDataType, + AccDataType, + OutElementwiseOperation, + InElementwiseOperation, + element_wise::PassThrough, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel_with_preprocess( + stream_config, + preprocess, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + p_c_grid, + arg.a_element_op_, + arg.b_element_op_, + element_wise::PassThrough{}, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + auto launch_elementwise_kernel = [&]() { + const AccDataType* p_c_grid = type_convert(arg.p_workspace_); + const index_t grid_size = + arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * + arg.Conv_G_; + + std::array in_out_batch_strides = { + arg.compute_ptr_offset_of_batch_.BatchStrideC_}; + + const auto kernel = kernel_batched_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + I1, + I1>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.ce_grid_desc_m_n_), + make_tuple(arg.ce_grid_desc_m_n_), + make_tuple(p_c_grid), + make_tuple(arg.p_e_grid_), + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_, + arg.Conv_G_, + in_out_batch_strides, + in_out_batch_strides); + }; + + float avg_time = 0; + if(has_main_k0_block_loop) + { + avg_time = launch_gemm_kernel(integral_constant{}); + } + else + { + avg_time = launch_gemm_kernel(integral_constant{}); + } + + avg_time += launch_elementwise_kernel(); + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // Check this here, it allows to use other instances from factory even + // if workspace is not allocated + if(!arg.p_workspace_) + { + std::cerr << "Warning: Workspace for " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not " + 
"allocated, use SetWorkSpacePointer." + << std::endl; + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(NDimSpatial == 1) + { + if constexpr(!is_GNWK_GKXC_GNWC()) + { + return false; + } + } + else if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGK_GKYXC_NHWGC() || + is_GNHWK_GKYXC_GNHWC())) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC() || + is_GNDHWK_GKZYXC_GNDHWC())) + { + return false; + } + } + else + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.ce_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return 
std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp new file mode 100644 index 000000000..8120eff25 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 8, 1, 8>, 1> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index dc56b8f4b..91b7df3d4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -352,6 +352,8 @@ struct DeviceOperationInstanceFactory>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); 
#endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( @@ -192,6 +204,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 340ddfb3f..a21b7702b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -5,7 +5,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp - xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp) + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp) if(DL_KERNELS) list(APPEND GROUPED_CONV2D_BWD_WEIGHT diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 000000000..ef583cf4f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); + // 2. 
Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 8b89dcf7e..435d1831e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -5,7 +5,8 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp) if(DL_KERNELS) list(APPEND GROUPED_CONV3D_BWD_WEIGHT diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 000000000..c4849c017 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. 
Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 5b981dda3..356aec7a0 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -188,6 +188,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, out_element_op, split_k); + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { // using atomic add, so need to reset input -- GitLab From fdbf8ccbd75fd0b255515b6f7cd385f47f9a22a8 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 8 May 2024 16:11:54 -0700 Subject: [PATCH 13/96] fix the output formatting (#1282) --- include/ck/host_utility/flush_cache.hpp | 2 +- include/ck/host_utility/kernel_launch.hpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index a93853c34..36993d0ae 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -119,7 +119,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", __func__, grid_dim.x, grid_dim.y, diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index df85f06c7..1cdb7f9c5 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -22,7 +22,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, { if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", __func__, grid_dim.x, grid_dim.y, @@ -97,7 +97,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { if(ck::EnvIsEnabled(ENV(CK_LOGGING))) { - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", __func__, grid_dim.x, grid_dim.y, -- GitLab From 3c043cd10b48c3a67c9ee4472336603bca08faec Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Thu, 9 May 2024 16:30:17 +0200 Subject: [PATCH 14/96] Add vector instruction coherency bits for gfx94 targets. (#1268) --- include/ck/utility/amd_buffer_addressing.hpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 678c55b95..cfa4cabee 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. 
All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" @@ -297,6 +297,17 @@ enum struct AmdBufferCoherenceEnum GLC = 1, SLC = 2, GLC_SLC = 3, + // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 + // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system + // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse + WAVE_NT0 = 0, + WAVE_NT1 = 2, + GROUP_NT0 = 1, + GROUP_NT1 = 3, + DEVICE_NT0 = 8, + DEVICE_NT1 = 10, + SYSTEM_NT0 = 9, + SYSTEM_NT1 = 11, }; template -- GitLab From a0ae1c61334f82ee62850d3356a9977b85b5ee2b Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Thu, 9 May 2024 18:42:41 +0200 Subject: [PATCH 15/96] Fix MakeArgument (#1284) --- example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp | 3 ++- .../device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index d80c163e3..2b891dd6f 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -92,9 +92,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co auto group_count = problem_size.group_count; using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + using GemmDesc = ck::tensor_operation::device::GemmDesc; // GEMM shape - std::vector gemm_descs; + std::vector gemm_descs; std::vector ggemm_kargs; std::vector p_Cs; std::vector p_As; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 7c252092a..403bc7fad 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -375,7 +375,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop std::vector& /* p_Bs */, std::vector>& /* p_Ds */, std::vector& /* p_Es */, - std::vector& gemm_descs, + const std::vector& gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, @@ -643,7 +643,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop std::vector& p_Bs, std::vector>& p_Ds, std::vector& p_Es, - std::vector gemm_descs, + std::vector& gemm_descs, AElementwiseOperation a_elementwise_op, BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op) -- GitLab From 8346af9c686649703904d3c8c5d81e89c4116d4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 10 May 2024 10:57:42 +0200 Subject: [PATCH 16/96] Change output gemm type to AccDataType in two stage conv bwd wei (#1283) --- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 25 ++++++++++++------- ..._conv_bwd_weight_xdl_bilinear_instance.hpp | 1 + ...t_grouped_conv_bwd_weight_xdl_bilinear.cpp | 2 ++ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 
a5ae0565f..3c33c7dbc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle K0PerBlock, ConvBackwardWeightSpecialization>{}; + static constexpr index_t MaxScalarPerVectorFP32 = 4; + static constexpr index_t WorkspaceInOutScalarPerVector = + is_same_v + ? math::min(CBlockTransferScalarPerVector_NWaveNPerXdl, MaxScalarPerVectorFP32) + : CBlockTransferScalarPerVector_NWaveNPerXdl; + // Bytes per 32 lds bank: 32 * 4 bytes static constexpr auto BankLength = 128; static constexpr auto ElePerBank = BankLength / sizeof(ADataType); @@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ADataType, BDataType, AccDataType, - EDataType, + AccDataType, InMemoryDataOperationEnum::AtomicAdd, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, @@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, - CBlockTransferScalarPerVector_NWaveNPerXdl, + WorkspaceInOutScalarPerVector, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true, @@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static constexpr auto MakeElementwiseInputSequence() { return generate_sequence_v2( - [&](auto) constexpr { return Number{}; }, + [&](auto) constexpr { return Number{}; }, Number{}); } @@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {})); using CDGridDesc_M_N = decltype(concat_tuple(Tuple{}, DsGridDesc_M_N{})); using DsGridPointerTuple = decltype(GetDsGridPointerTuple()); - using CDDataTypes = decltype(concat_tuple(Tuple{}, DsGridPointerTuple{})); + using CDDataTypes = decltype(concat_tuple(Tuple{}, DsGridPointerTuple{})); using EGridDesc_M_N = CGridDesc_M_N; static constexpr index_t ClusterLengthMPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); @@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle std::size_t GetWorkspaceSizeBytes() const { - return sizeof(EDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; + return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; } const ADataType* p_a_grid_; @@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); auto launch_gemm_kernel = [&](auto has_main_k_block_loop) { - EDataType* p_c_grid = type_convert(arg.p_workspace_); + AccDataType* p_c_grid = type_convert(arg.p_workspace_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; @@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle GridwiseGemm, ADataType, BDataType, - EDataType, + AccDataType, OutElementwiseOperation, InElementwiseOperation, element_wise::PassThrough, @@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle }; auto launch_elementwise_kernel = [&]() { - const EDataType* p_c_grid = type_convert(arg.p_workspace_); + const AccDataType* p_c_grid = type_convert(arg.p_workspace_); const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; @@ -907,7 +913,8 @@ struct 
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle } // vector store C matrix into global memory - if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0 && + arg.Conv_C_ % WorkspaceInOutScalarPerVector == 0)) { return false; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp index dfd321644..8b830d91d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp @@ -86,6 +86,7 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_bilinear_instances = std: //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | // generic instance + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 diff --git a/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp b/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp index d733325a9..11748d471 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp @@ -264,5 +264,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 4, 4, {3, 3, 3}, {14, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->Run(); } -- GitLab From fcba889ef461bb334e8f74ea465713f5b7611855 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 11 May 2024 00:03:39 +0800 Subject: [PATCH 17/96] [CK_TILE] fix some rand number init (#1287) * add random norm * normalized default to 0/3 * change squant->auto --- example/ck_tile/01_fmha/README.md | 11 ++-- example/ck_tile/01_fmha/fmha_fwd.cpp | 87 ++++++++++++++++++---------- 2 files changed, 65 insertions(+), 33 deletions(-) diff --git 
a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index fd5690a79..a3248e2a5 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -44,9 +44,9 @@ args: -range_v per-tensor quantization range of v. used if squant=1. (default:16) -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) - -squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0) - 1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p, - scale_o according to range_q, range_k, range_v, range_p, range_o + -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto) + 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O. + calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o -iperm permute input (default:1) if true, will be b*h*s*d, else b*s*h*d -operm permute output (default:1) @@ -64,8 +64,11 @@ args: -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) -lse 0 not store lse, 1 store lse (default:0) -kname if set to 1 will print kernel name (default:0) - -init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1) + -init init method. ui, uniform random int, ni, normalized random int (default:uf) + uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization -seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939) + -warmup number of iterations before benchmark the kernel (default:5) + -repeat number of iterations to benchmark the kernel (default:20) ``` Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 686633bb2..74cb3657e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -60,12 +60,14 @@ auto create_args(int argc, char* argv[]) .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") - .insert( - "squant", - "0", - "if using static quantization fusion or not. 0: original flow(not prefered)\n" - "1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,\n" - "scale_o according to range_q, range_k, range_v, range_p, range_o") + .insert("squant", + "auto", + "if using static quantization fusion or not. auto: fp8 will default use squant, " + "other will not\n" + "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " + "P and O.\n" + "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " + "range_p, range_o") .insert("iperm", "1", "permute input\n" @@ -92,8 +94,11 @@ auto create_args(int argc, char* argv[]) .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("lse", "0", "0 not store lse, 1 store lse") .insert("kname", "0", "if set to 1 will print kernel name") - .insert( - "init", "1", "init method. 
0:random int, 1:random float, 2:trig float, 3:quantization") + .insert("init", + "uf", + "init method. ui, uniform random int, ni, normalized random int\n" + "uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, " + "quantization") .insert("seed", "11939", "random seed used for initializing input tensors. 0 for " @@ -107,7 +112,7 @@ auto create_args(int argc, char* argv[]) // different threshold for different dtype template -auto get_elimit(int /*init_method*/) +auto get_elimit(std::string /*init_method*/) { double rtol = 1e-3; double atol = 1e-3; @@ -115,9 +120,15 @@ auto get_elimit(int /*init_method*/) } template <> -auto get_elimit(int init_method) +auto get_elimit(std::string init_method) { - if(init_method == 0) + if(init_method == "ui" || init_method == "ni") + { + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); + } + else if(init_method == "nf") { double rtol = 1e-2; double atol = 1e-2; @@ -132,9 +143,9 @@ auto get_elimit(int init_method) } template <> -auto get_elimit(int init_method) +auto get_elimit(std::string init_method) { - if(init_method == 0) + if(init_method == "ui" || init_method == "ni") { unsigned max_rounding_point_distance = 0; double atol = 2e-3; @@ -182,15 +193,18 @@ bool run(const ck_tile::ArgParser& arg_parser) if(scale_s == .0f) scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? - bool squant = arg_parser.get_bool("squant"); - if constexpr(!std::is_same_v) - { - if(squant) + std::string squant_str = arg_parser.get_str("squant"); + bool squant = [&]() { + if(squant_str == "auto") { - std::cerr << "static quantization only support fp8 for now" << std::endl; - return false; + if(data_type == "fp8") + return true; + else + return false; } - } + else + return atoi(squant_str.c_str()) != 0 ? 
true : false; + }(); float range_q = arg_parser.get_float("range_q"); float range_k = arg_parser.get_float("range_k"); @@ -217,7 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser) bias_info bias = bias_info::decode(arg_parser.get_str("bias")); mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); - int init_method = arg_parser.get_int("init"); + std::string init_method = arg_parser.get_str("init"); std::optional seed = arg_parser.get_uint32("seed"); if(*seed == 0) { @@ -319,28 +333,43 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor o_host( get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); - if(init_method == 0) + if(init_method == "ui" || init_method == "0") { - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); } - else if(init_method == 1) + else if(init_method == "ni") + { + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + } + else if(init_method == "uf" || init_method == "1") { ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); } - else if(init_method == 2) + else if(init_method == "nf") + { + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(bias_host); + } + else if(init_method == "tf" || init_method == "2") { ck_tile::FillTrigValue{}(q_host); ck_tile::FillTrigValue{}(k_host); ck_tile::FillTrigValue{}(v_host); ck_tile::FillTrigValue{}(bias_host); } - else if(init_method == 3) // suitable for fp8 quantization + else if(init_method == "ufq" || init_method == "uf:q" || + init_method == "3") // suitable for fp8 quantization { ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(q_host); ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(k_host); -- GitLab From 566b6480a2e6e1245033f256eca0dce097bd5d75 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 10 May 2024 09:41:39 -0700 Subject: [PATCH 18/96] Code clean-up (#1285) * code clean-up * remove the profiling output samples --- CMakeLists.txt | 6 +- Jenkinsfile | 39 ++++----- client_example/25_wrapper/wrapper_img2col.cpp | 1 - example/01_gemm/README.md | 14 ---- example/02_gemm_bilinear/README.md | 17 ---- example/04_gemm_add_add_fastgelu/README.md | 13 --- example/09_convnd_fwd/README.md | 14 ---- example/15_grouped_gemm/README.md | 16 ---- example/26_contraction/README.md | 11 --- .../30_grouped_conv_fwd_multiple_d/README.md | 12 --- example/46_gemm_add_multiply/README.md | 16 
---- include/ck/ck.hpp | 2 +- include/ck/host_utility/device_prop.hpp | 6 +- ...d_contraction_multiple_d_wmma_cshuffle.hpp | 2 +- .../device_batched_gemm_multiple_d_dl.hpp | 2 +- ...emm_softmax_gemm_permute_wmma_cshuffle.hpp | 4 +- .../device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp | 4 +- .../device/impl/device_fpAintB_gemm_wmma.hpp | 2 +- .../gpu/device/impl/device_gemm_dl.hpp | 4 +- .../gpu/device/impl/device_gemm_dpp.hpp | 2 +- .../device/impl/device_gemm_multiple_d_dl.hpp | 2 +- .../device_gemm_multiple_d_wmma_cshuffle.hpp | 2 +- .../gpu/device/impl/device_gemm_wmma.hpp | 2 +- ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 2 +- ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 2 +- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 2 +- ...ice_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp | 4 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 2 +- .../device_grouped_gemm_multiple_d_dl.hpp | 2 +- ...e_grouped_query_attention_forward_wmma.hpp | 4 +- ...ice_multi_query_attention_forward_wmma.hpp | 4 +- .../gpu/grid/block_to_ctile_map.hpp | 2 +- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 2 +- include/ck/utility/amd_xdlops.hpp | 2 +- include/ck/utility/type_convert.hpp | 2 +- profiler/README.md | 83 ------------------- script/test_convnd_fwd.sh | 2 +- .../test_grouped_convnd_bwd_weight.cpp | 8 +- 38 files changed, 57 insertions(+), 259 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e3113a31d..c23746e7f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -202,7 +202,7 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) -option(USE_OPT_NAVI3X "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) +option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." 
OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -210,10 +210,10 @@ if(USE_BITINT_EXTENSION_INT4) message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() -if(USE_OPT_NAVI3X) +if(USE_OPT_GFX11) add_compile_options(-mcumode) add_compile_options(-mno-wavefrontsize64) - message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}") + message("CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") endif() ## Threads diff --git a/Jenkinsfile b/Jenkinsfile index d334549bb..75800bfc9 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -515,30 +515,25 @@ def Build_CK(Map conf=[:]){ withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 24, unit: 'HOURS') { - //check whether running on Navi or MI300 node - def navi_node = 0 - def mi300_node = 0 + //check whether to run performance tests on this node + def do_perf_tests = 0 sh 'rocminfo | tee rocminfo.log' - if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') ){ - navi_node = 1 - echo "This is a Navi node" - } - if ( runShell('grep -n "gfx942" rocminfo.log') ){ - mi300_node = 1 - echo "This is MI300 node" + if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') || runShell('grep -n "gfx942" rocminfo.log') ){ + do_perf_tests = 1 + echo "Stash profiler and run performance tests" } cmake_build(conf) dir("build"){ //run tests and examples sh 'make -j check' - if (params.RUN_PERFORMANCE_TESTS && navi_node == 0 && mi300_node == 0 ){ + if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0 ){ //we only need the ckProfiler to run the performance tests, so we pack and stash it - //do not stash profiler on Navi or MI300 nodes + //do not stash profiler on nodes where we don't need to run performance tests sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' stash name: "ckProfiler.tar.gz" } - if (params.RUN_FULL_QA && mi300_node == 0 ){ - // build deb packages for all MI100/200/300 targets and prepare to export + if (params.RUN_FULL_QA && do_perf_tests == 0 ){ + // build deb packages for all gfx9 targets and prepare to export sh 'make -j package' archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' archiveArtifacts artifacts: 'composablekernel-tests_*.deb' @@ -546,7 +541,7 @@ def Build_CK(Map conf=[:]){ stash name: "ckprofiler_0.2.0_amd64.deb" } } - if (params.hipTensor_test && navi_node == 0 ){ + if (params.hipTensor_test && do_perf_tests == 0 ){ //build and test hipTensor sh """#!/bin/bash rm -rf "${params.hipTensor_branch}".zip @@ -814,7 +809,7 @@ pipeline { { parallel { - stage("Run Codegen Tests on MI200") + stage("Run Codegen Tests on gfx90a") { when { beforeAgent true @@ -865,7 +860,7 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on MI300") + stage("Build CK and run Tests on gfx942") { when { beforeAgent true @@ -885,7 +880,7 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on MI200") + stage("Build CK and run Tests on gfx90a") { when { beforeAgent true @@ -925,13 +920,13 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on Navi21") + stage("Build CK and run Tests on gfx1030") { when { beforeAgent true expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } } - agent{ label rocmnode("navi21") } + agent{ label rocmnode("gfx1030") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " 
""" execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ @@ -945,13 +940,13 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on Navi32") + stage("Build CK and run Tests on gfx1101") { when { beforeAgent true expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } } - agent{ label rocmnode("navi32") } + agent{ label rocmnode("gfx1101") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp index 2a4034d62..ceccc5eb8 100644 --- a/client_example/25_wrapper/wrapper_img2col.cpp +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -181,4 +181,3 @@ int main(int argc, char* argv[]) {1, 1, 1} /*filter_dilations*/); return 0; } -// MI100 Perf: 0.255178 ms, 1698.9 GB/s, diff --git a/example/01_gemm/README.md b/example/01_gemm/README.md index 226783b03..a09e69255 100644 --- a/example/01_gemm/README.md +++ b/example/01_gemm/README.md @@ -7,17 +7,3 @@ #arg3: run kernel # of times (>1) ./bin/example_gemm_xdl 0 1 5 ``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s -``` diff --git a/example/02_gemm_bilinear/README.md b/example/02_gemm_bilinear/README.md index 9eb87e1e3..a407ce24f 100644 --- a/example/02_gemm_bilinear/README.md +++ b/example/02_gemm_bilinear/README.md @@ -9,20 +9,3 @@ #arg11 to 12: alpha, beta ./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5 ``` -Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c0_grid_desc_m_n_{ 3840, 4096} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... 
-Perf: 0.936965 ms, 137.517 TFlops, 102.959 GB/s -error: 0 -max_diff: 0, 558.5, 558.5 -``` diff --git a/example/04_gemm_add_add_fastgelu/README.md b/example/04_gemm_add_add_fastgelu/README.md index 08a55fb9a..7b0d003e5 100644 --- a/example/04_gemm_add_add_fastgelu/README.md +++ b/example/04_gemm_add_add_fastgelu/README.md @@ -8,16 +8,3 @@ #arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE" ./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1 ``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1} -d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8> -``` diff --git a/example/09_convnd_fwd/README.md b/example/09_convnd_fwd/README.md index 9ab5fee54..22f90ea29 100644 --- a/example/09_convnd_fwd/README.md +++ b/example/09_convnd_fwd/README.md @@ -16,17 +16,3 @@ # , (ie RightPy, RightPx for 2D) ./bin/example_convnd_fwd_xdl 0 1 100 ``` - -Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32) -``` -input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} -weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} -output: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} -arg.a_grid_desc_k0_m_k1_{432, 165888, 4} -arg.b_grid_desc_k0_n_k1_{432, 256, 4} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 100 times... -Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s -``` diff --git a/example/15_grouped_gemm/README.md b/example/15_grouped_gemm/README.md index c83b23e08..a2afe0f4b 100644 --- a/example/15_grouped_gemm/README.md +++ b/example/15_grouped_gemm/README.md @@ -7,19 +7,3 @@ #arg3: run kernel # of times (>1) ./bin/example_grouped_gemm_xdl_fp16 0 1 5 ``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -gemm[0] a_m_k: dim 2, lengths {256, 64}, strides {64, 1} b_k_n: dim 2, lengths {64, 128}, strides {1, 64} c_m_n: dim 2, lengths {256, 128}, strides {128, 1} -gemm[1] a_m_k: dim 2, lengths {512, 128}, strides {128, 1} b_k_n: dim 2, lengths {128, 256}, strides {1, 128} c_m_n: dim 2, lengths {512, 256}, strides {256, 1} -gemm[2] a_m_k: dim 2, lengths {768, 192}, strides {192, 1} b_k_n: dim 2, lengths {192, 384}, strides {1, 192} c_m_n: dim 2, lengths {768, 384}, strides {384, 1} -gemm[3] a_m_k: dim 2, lengths {1024, 256}, strides {256, 1} b_k_n: dim 2, lengths {256, 512}, strides {1, 256} c_m_n: dim 2, lengths {1024, 512}, strides {512, 1} -group: 0 arg.a_grid_desc_k0_m_k1_{8, 256, 8}, arg.b_grid_desc_k0_n_k1_{8, 128, 8}, arg.c_grid_desc_m_n_{ 256, 128} -group: 1 arg.a_grid_desc_k0_m_k1_{16, 512, 8}, arg.b_grid_desc_k0_n_k1_{16, 256, 8}, arg.c_grid_desc_m_n_{ 512, 256} -group: 2 arg.a_grid_desc_k0_m_k1_{24, 768, 8}, arg.b_grid_desc_k0_n_k1_{24, 384, 8}, arg.c_grid_desc_m_n_{ 768, 384} -group: 3 arg.a_grid_desc_k0_m_k1_{32, 1024, 8}, arg.b_grid_desc_k0_n_k1_{32, 512, 8}, arg.c_grid_desc_m_n_{ 1024, 512} -launch_and_time_kernel: grid_dim {30, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... 
-Perf: 0.037887 ms, 11.0706 TFlops, 90.8132 GB/s, DeviceGroupedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> -``` diff --git a/example/26_contraction/README.md b/example/26_contraction/README.md index c88d93cf8..acbfa84df 100644 --- a/example/26_contraction/README.md +++ b/example/26_contraction/README.md @@ -7,14 +7,3 @@ #arg3: time kernel (0=no, 1=yes) ./bin/example_contraction_bilinear_xdl_fp32 1 1 1 ``` - -Result (MI100 @ dynammic freq, 46TFlops peak FP32) -``` -a_ms_ks: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1} -b_ks_ns: dim 4, lengths {32, 64, 32, 64}, strides {128, 1, 524288, 4096} -c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1} -launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4> -``` diff --git a/example/30_grouped_conv_fwd_multiple_d/README.md b/example/30_grouped_conv_fwd_multiple_d/README.md index 7a0cb2d0e..1165634e1 100644 --- a/example/30_grouped_conv_fwd_multiple_d/README.md +++ b/example/30_grouped_conv_fwd_multiple_d/README.md @@ -16,15 +16,3 @@ Following arguments (depending on number of spatial dims): ./bin/example_grouped_conv_fwd_bias_relu_add_xdl_fp16 1 1 1 ``` -Result (MI100) -``` -in: dim 5, lengths {1, 128, 192, 71, 71}, strides {192, 967872, 1, 13632, 192} -wei: dim 5, lengths {1, 256, 192, 3, 3}, strides {442368, 1728, 1, 576, 192} -bias: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0} -residual: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0} -out: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 331776, 1, 9216, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 1.55981 ms, 94.0927 TFlops, 213.868 GB/s, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Default> -``` diff --git a/example/46_gemm_add_multiply/README.md b/example/46_gemm_add_multiply/README.md index ee5cdee36..e2de4696f 100644 --- a/example/46_gemm_add_multiply/README.md +++ b/example/46_gemm_add_multiply/README.md @@ -8,19 +8,3 @@ #arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE" ./bin/example_gemm_add_multiply_dl_fp16 1 1 1 ``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} -d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1} -d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2} -arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} -arg.e_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... 
-Perf: 3.99904 ms, 32.22 TFlops, 31.9913 GB/s, DeviceGemmMultipleD_Dl<256, 128, 128, 16, 2, 4, 4, 1> -``` diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index c8025f53c..55f562061 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -236,7 +236,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #ifndef CK_WORKAROUND_DENORM_FIX #define CK_WORKAROUND_DENORM_FIX 0 #else -// enable only on MI200 +// enable only for gfx90a #define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) #endif // CK_WORKAROUND_DENORM_FIX diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 13e526875..116bb3ea0 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -65,20 +65,20 @@ inline bool is_lds_direct_load_supported() ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; } -inline bool is_navi1_supported() +inline bool is_gfx101_supported() { return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" || ck::get_device_name() == "gfx1012"; } -inline bool is_navi2_supported() +inline bool is_gfx103_supported() { return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" || ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" || ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036"; } -inline bool is_navi3_supported() +inline bool is_gfx11_supported() { return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index d35645c06..a15759559 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp index b01e029c0..8fd14afc0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp @@ -648,7 +648,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD || is_same_v)) { @@ -1435,7 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle #if 0 static bool IsSupportedArgument(const Argument& arg) { - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index 5d9f8a178..149aca7e3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1392,8 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl static bool 
IsSupportedArgument(const Argument& arg) { // check device - if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() || - ck::is_navi3_supported())) + if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || + ck::is_gfx11_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp index 4385d64c1..bf96324d0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index 515892142..d3af5e63d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -535,8 +535,8 @@ struct DeviceGemmDl : public DeviceGemm || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index a7f230529..93ab8a7e1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -443,7 +443,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index b0e0e6da7..6f74838fb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { // check device - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index b9436c21a..211185dfb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { // check device - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index c3023301f..7cfbd8a8f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -666,7 +666,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK // check device if(!(ck::get_device_name() == "gfx906" || 
ck::is_xdl_supported() || - ck::is_navi2_supported() || ck::is_navi3_supported())) + ck::is_gfx103_supported() || ck::is_gfx11_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index d731e5dda..6a4d97d7d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -601,8 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index 37c5b5c91..a88c7b4fb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -673,7 +673,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm || is_same_v)) { @@ -958,7 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma #if 0 static bool IsSupportedArgument(const Argument& arg) { - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp index b7551e78a..4e14ed3a5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp @@ -594,7 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma static bool IsSupportedArgument(const RawArg& arg) { - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported()) { if constexpr(!(is_same_v || is_same_v)) { @@ -950,7 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma #if 0 static bool IsSupportedArgument(const Argument& arg) { - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index d92f504d5..84b00fcbd 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -260,7 +260,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt struct BlockToCTileMap_Grouped_M00_N0_M01Adapt diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 70fbcec10..565195f53 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -95,7 +95,7 @@ struct wmma_type{}; - // * Fixed in Navi3x, Will be wave mode dependent on Navi4x + // * Fixed on gfx11, Will be wave mode dependent for future architectures static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; // * num_acc_vgprs_per_wave alone M direction diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 0ee52b957..d8ccb2ea7 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ 
b/include/ck/utility/amd_xdlops.hpp @@ -4,7 +4,7 @@ #pragma once namespace ck { -// Define the common macro for MI300 models +// Define the common macro for gfx94x models #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define __gfx94__ #endif diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index be74b1fdc..382b9c555 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -8,7 +8,7 @@ #include "ck/utility/random_gen.hpp" namespace ck { -// Define the common macro for MI300 models +// Define the common macro for gfx94x models #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define __gfx94__ #endif diff --git a/profiler/README.md b/profiler/README.md index a4daefba9..10febcabd 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -13,15 +13,6 @@ ./bin/ckProfiler gemm 1 1 1 1 0 5 3840 4096 4096 4096 4096 4096 ``` -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -```bash -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -.... -Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s -``` - ## Profile 2D forward convolution kernels ```bash #arg1: tensor operation (conv=Convolution) @@ -37,15 +28,6 @@ Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s ################ op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads ./bin/ckProfiler conv2d_fwd 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 ``` -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) - -```bash -in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} -wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} -out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} -.... -Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s -``` ## Profile contraction kernels ```bash @@ -71,16 +53,6 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s ./bin/ckProfiler contraction_bilinear 0 0 2 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 ``` -Result (MI100) -```bash -a_m_k: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} -b_k_n: dim 4, lengths {128, 128, 128, 128}, strides {128, 1, 2097152, 16384} -d_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} -e_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} -.... -Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s -``` - ## Profile batched gemm multiple D kernels ```bash #arg1: tensor operation (batched_gemm_multi_d=Batched GEMM multi D); @@ -99,14 +71,6 @@ Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s ./bin/ckProfiler batched_gemm_multi_d 0 1 0 0 0 1 4096 4096 4096 4096 4096 4096 16777216 16777216 16777216 16 ``` -Result (Radeon RX 6800 XT) -```bash -arg.a_grid_desc_k0_m0_m1_k1_{2048, 4096, 2} -arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} -arg.e_grid_desc_m_n_{ 4096, 4096} -.... 
-Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s -``` ## Profile grouped convolution backward data kernels ```bash # arg1: tensor operation (grouped_conv_bwd_data: Grouped Convolution Backward Data) @@ -134,20 +98,6 @@ Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s ``` -Result (MI100, FP16, GNHWC_GKYXC_GNHWK) - -```bash -out: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192} -wei: dim 5, lengths {32, 192, 192, 3, 3}, strides {331776, 1728, 1, 576, 192} -in: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192} -.... -Best configuration parameters: -name: DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<256, 128, 256, 32, 8, 2, Default, 32, 32, 2, 4, 8, 4, 1, 1> -avg_time: 0.768321 -tflops: 86.6679 -GB/s: 127.947 -``` - ## Profile grouped convolution backward weight kernels ```bash # arg1: tensor operation (grouped_conv_bwd_weight: Grouped Convolution Backward Weight) @@ -179,19 +129,6 @@ GB/s: 127.947 ``` -Result (MI100, FP16, GNHWC_GKYXC_GNHWK) - -```bash -input: dim 5, lengths {32, 512, 1024, 28, 28}, strides {411041792, 802816, 1, 28672, 1024} -weight: dim 5, lengths {32, 512, 1024, 3, 3}, strides {4718592, 9216, 1, 3072, 1024} -output: dim 5, lengths {32, 512, 512, 26, 26}, strides {177209344, 346112, 1, 13312, 512} -.... -Best configuration parameters: -name: DeviceGroupedConvBwdWeight_Xdl_CShuffle<256, 256, 128, 4, Default, 8, 4, 2, 8, 4, 8, 2, 1, 1, 8> -avg_time: 68.5216 -tflops: 95.337 -GB/s: 69.2301 -``` Note: This kernel use atomic add, this will cause output buffer to be accumulated multiple times, causing verification failure. To work around it, do not use CK's own timer and do verification at the same time. ## Profile image to column/column to image kernels @@ -224,17 +161,6 @@ Note: This kernel use atomic add, this will cause output buffer to be accumulate ``` -Result (MI210, FP32, NHWC) - -```bash -input: dim 5, lengths {1, 256, 512, 28, 28}, strides {102760448, 401408, 1, 14336, 512} -output: dim 2, lengths {173056, 4608}, strides {4608, 1} -.... -Best configuration parameters: -name: DeviceImageToColumn<128, 32, 64, 4> -avg_time: 3.12326 -GB/s: 2042.59 -``` Note: Column to image kernel adds to the output memory, this will cause output buffer to be accumulated multiple times, causing verification failure. To work around it, do not use CK's own timer and do verification at the same time. ## Profile Permute scale kernels @@ -254,12 +180,3 @@ Note: Column to image kernel adds to the output memory, this will cause output b ################ op datatype verify init log time dim0 dim1 dim2 in_stride0 in_stride1 in_stride2 out_stride0 out_stride1 out_stride2 ./bin/ckProfiler permute_scale 0 1 1 0 1 64 64 64 4096 64 1 1 64 4096 ``` - -Result (MI100, FP32) - -```bash -A: dim 3, lengths {64, 64, 64}, strides {4096, 64, 1} -B: dim 3, lengths {64, 64, 64}, strides {1, 64, 4096} -.... 
-Best perf = 0.0146878 ms, 142.782 GB/s, DeviceElementwiseNormalizationImpl<3, 2> -``` diff --git a/script/test_convnd_fwd.sh b/script/test_convnd_fwd.sh index 1bd7a6b5d..8bd2c2fc3 100644 --- a/script/test_convnd_fwd.sh +++ b/script/test_convnd_fwd.sh @@ -65,7 +65,7 @@ set -- "${POSITIONAL[@]}" # restore positional parameters # NUMACTL="numactl --cpunodebind=1 --membind=1" NUMACTL= # ENV_CONF= -GPU=mi100 +GPU=gfx908 PROF_ITER_COUNT=10000 LOG_DIR_PATH=../log/${LOG_DIR} set -x diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index d100fb107..1c8082645 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -55,14 +55,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } } - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported()) { - // on navi3x only support for 3d is implemented + // on gfx11 only support for 3d is implemented if constexpr(NDimSpatial{} != 3) { return true; } - // on navi3x only support for i8 and fp16 is implemented + // on gfx11 only support for i8 and fp16 is implemented if constexpr(!((std::is_same_v && std::is_same_v && std::is_same_v) || @@ -80,7 +80,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } else { - // support for i8 is only implemented on navi3x + // support for i8 is only implemented on gfx11 if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { -- GitLab From 7843a8a7fbd0afc49cd4b8fa0766ea9906174b0d Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 10 May 2024 22:48:28 -0700 Subject: [PATCH 19/96] re-enable convnd_fwd_xdl_fp64 testing (#1289) --- example/09_convnd_fwd/CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index c57679827..8a295d14c 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -3,8 +3,7 @@ add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) -# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed -add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) +add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) add_example_executable(example_convnd_fwd_xdl_fp16_comp_fp8 convnd_fwd_xdl_fp16_comp_fp8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp) -- GitLab From 3e3471d5d28e857c8d95dece447b2f9e18c90b4c Mon Sep 17 00:00:00 2001 From: jakpiase Date: Wed, 15 May 2024 10:03:39 +0200 Subject: [PATCH 20/96] Add unit tests for grouped gemm two stage (#1256) * add unit tests for grouped gemm two stage * add reviewers suggestions --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- ...ltiple_d_splitk_xdl_cshuffle_two_stage.hpp | 11 ++-- test/grouped_gemm/CMakeLists.txt | 6 ++ ...d_gemm_two_stage_multiple_d_splitk_xdl.cpp | 62 +++++++++++++++++++ .../test_grouped_gemm_two_stage_ut_cases.inc | 61 ++++++++++++++++++ test/grouped_gemm/test_grouped_gemm_util.hpp | 55 
+++++++++++++++- 5 files changed, 189 insertions(+), 6 deletions(-) create mode 100644 test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp create mode 100644 test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index cb32587d3..a70ee6f05 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -337,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage elementwise_d_grid_descs_m_n_.reserve(group_count_); ds_grid_pointer_.reserve(group_count_); group_grid_size_.reserve(group_count_); + e_ptrs_.reserve(group_count_); for(std::size_t i = 0; i < gemm_descs.size(); ++i) { @@ -380,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage const index_t block_end = grid_size_ + grid_size_grp; grid_size_ += grid_size_grp; - group_grid_size_[i] = grid_size_grp; + group_grid_size_.push_back(grid_size_grp); // block-to-e-tile map auto grouped_block_2_ctile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); @@ -421,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n); elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n); ds_grid_pointer_.push_back(p_ds_grid); + // Store a copy of E pointers for elementwise kernel destination + e_ptrs_.push_back(p_Es[i]); } - // Store a copy of E pointers for elementwise kernel destination - e_ptrs_ = p_Es; } /** @@ -774,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage dim3(BlockSize), 0, cast_pointer_to_constant_address_space(dev_gemm_args), - arg.group_count_, + arg.gemm_kernel_args_.size(), arg.a_element_op_, arg.b_element_op_, PassThrough{}); // Elementwise kernels - for(int i = 0; i < arg.group_count_; ++i) + for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) { time += launch_and_time_kernel( stream_config, diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index f47685cf9..55cb20977 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -6,6 +6,12 @@ if(result EQUAL 0) add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) endif() +add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk) +endif() + add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) diff --git a/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp b/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp new file mode 100644 index 000000000..67ecbaea3 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +#include "gtest/gtest.h" +#include "test_grouped_gemm_util.hpp" + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using RRR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage>; +using RCR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage>; +using RRR_F16_F16_F16_LargeK = + ck::test::TestGroupedGemmTwoStage>; +using RCR_F16_F16_F16_LargeK = + ck::test::TestGroupedGemmTwoStage>; +using RRR_BF16_BF16_BF16 = + ck::test::TestGroupedGemmTwoStage>; +using RCR_BF16_BF16_BF16 = + ck::test::TestGroupedGemmTwoStage>; +using RRR_BF16_I8_BF16 = + ck::test::TestGroupedGemmTwoStage>; +using RCR_BF16_I8_BF16 = + ck::test::TestGroupedGemmTwoStage>; + +const std::vector KBATCH{1, 2, 3, 5, 8}; + +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN, + RRR_F16_F16_F16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK, + RCR_F16_F16_F16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16, + RRR_BF16_BF16_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16, + RCR_BF16_BF16_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16_INT8, + RRR_BF16_I8_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16_INT8, + RCR_BF16_I8_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_KN, + RRR_F16_F16_F16_LargeK, + testing::Values(32, 64)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_NK, + RCR_F16_F16_F16_LargeK, + testing::Values(32, 64)); + +#include "test_grouped_gemm_ut_cases.inc" +#include "test_grouped_gemm_two_stage_ut_cases.inc" diff --git a/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc new file mode 100644 index 000000000..40d48f4ec --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc @@ -0,0 +1,61 @@ +#pragma once + +TEST_P(RRR_BF16_BF16_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RCR_BF16_BF16_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_BF16_I8_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, 
this->GetParam()); +} + +TEST_P(RCR_BF16_I8_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index 50f423ada..9e1395b9f 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,6 +22,7 @@ #include "ck/utility/tuple.hpp" #include "ck/utility/number.hpp" #include "profiler/profile_grouped_gemm_impl.hpp" +#include "profiler/profile_grouped_gemm_two_stage_impl.hpp" namespace ck { namespace test { @@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam } }; +template +class TestGroupedGemmTwoStage : public testing::TestWithParam +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using ELayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using EDataType = std::tuple_element_t<5, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + + void SetUp() override {} + + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + template Date: Wed, 15 May 2024 23:06:50 +0800 Subject: [PATCH 21/96] remove operator-deref (#1291) --- include/ck_tile/core/numeric/integral_constant.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp index ea7a67abc..33c24da8c 100644 --- a/include/ck_tile/core/numeric/integral_constant.hpp +++ b/include/ck_tile/core/numeric/integral_constant.hpp @@ -56,7 +56,6 @@ CK_TILE_LEFT_UNARY_OP(+) CK_TILE_LEFT_UNARY_OP(-) CK_TILE_LEFT_UNARY_OP(~) CK_TILE_LEFT_UNARY_OP(!) 
-CK_TILE_LEFT_UNARY_OP(*) CK_TILE_BINARY_OP(+) CK_TILE_BINARY_OP(-) -- GitLab From c44137838e2cb30bbe5a3b9903c357b476a34d52 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 15 May 2024 08:08:17 -0700 Subject: [PATCH 22/96] remove wrong use of nonexistent class members (#1290) --- .../ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 701dd04f6..e5e6245cb 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -795,11 +795,6 @@ struct BlockwiseGemmXdlops_v2 "wrong!"); } - __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other) - : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin) - { - } - // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() { -- GitLab From aaa8dfdae90f21cd93de170457bebc72a933566d Mon Sep 17 00:00:00 2001 From: rocking Date: Fri, 17 May 2024 17:19:17 +0800 Subject: [PATCH 23/96] Fix compile error (#1292) error: no viable conversion from returned value of type '__half' to function return type 'fp16_hip_t' (aka '_Float16') Co-authored-by: carlushuang --- include/ck_tile/core/numeric/half.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index c616b6939..752145f71 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -129,8 +129,8 @@ constexpr double fp16_to_double_hip(const fp16_hip_t& x) CK_TILE_HOST_DEVICE constexpr fp16_hip_t float_to_fp16_hip(const float& x) { - return __float2half(x); - // return static_cast(x); + // return __float2half(x); + return static_cast(x); } CK_TILE_HOST_DEVICE -- GitLab From 6637a810d0d4f80d1491762d90253fa1540b4cfa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 17 May 2024 07:44:48 -0700 Subject: [PATCH 24/96] Bump rocm-docs-core from 1.1.1 to 1.1.2 in /docs/sphinx (#1293) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 1.1.1 to 1.1.2. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v1.1.1...v1.1.2) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index dc1824931..f7843bd30 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.1.1 +rocm-docs-core==1.1.2 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 9a451d970..02d5f6501 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==1.1.1 +rocm-docs-core==1.1.2 # via -r requirements.in six==1.16.0 # via -- GitLab From 1274861a9da0d3051a9f5177a3640464b4c79d6a Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 17 May 2024 10:42:51 -0700 Subject: [PATCH 25/96] replace the ENV macro with CK_ENV (#1296) --- include/ck/host_utility/flush_cache.hpp | 6 +++--- include/ck/host_utility/kernel_launch.hpp | 8 ++++---- ...ultiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 2 +- ...evice_batched_gemm_reduce_xdl_cshuffle.hpp | 2 +- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 2 +- ...ice_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp | 2 +- ...fle_bias_activation_add_nhwc_kyxc_nhwk.hpp | 2 +- ...shuffle_bias_activation_nhwc_kyxc_nhwk.hpp | 2 +- ...onv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 2 +- .../device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp | 2 +- ...evice_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp | 2 +- .../device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp | 2 +- ...device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp | 2 +- .../gpu/device/impl/device_gemm_dl.hpp | 2 +- .../impl/device_gemm_reduce_xdl_cshuffle.hpp | 2 +- .../device_gemm_xdl_layernorm_cshuffle.hpp | 2 +- .../impl/device_gemm_xdl_skip_b_lds.hpp | 2 +- .../device_grouped_gemm_multiple_d_dl.hpp | 2 +- ...ltiple_d_splitk_xdl_cshuffle_two_stage.hpp | 8 ++++---- ...gemm_multiple_d_xdl_cshuffle_tile_loop.hpp | 2 +- .../device/impl/device_grouped_gemm_xdl.hpp | 2 +- ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 4 ++-- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 20 +++++++++---------- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 18 ++++++++--------- .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 20 +++++++++---------- include/ck/utility/env.hpp | 2 +- .../profile_grouped_gemm_fixed_nk_impl.hpp | 2 +- .../profiler/profile_grouped_gemm_impl.hpp | 2 +- .../profile_grouped_gemm_tile_loop_impl.hpp | 2 +- .../profile_grouped_gemm_two_stage_impl.hpp | 2 +- 30 files changed, 65 insertions(+), 65 deletions(-) diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index 36993d0ae..041428e6a 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -117,7 +117,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, #define MEDIAN 1 if(stream_config.time_kernel_) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", __func__, @@ -142,7 +142,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { return 0.0; } - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("Start running %d times...\n", nrepeat); } @@ -186,7 +186,7 @@ float 
launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, total_time += cur_time; #endif - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 1cdb7f9c5..a616433ac 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -20,7 +20,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, #if CK_TIME_KERNEL if(stream_config.time_kernel_) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", __func__, @@ -41,7 +41,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, } const int nrepeat = stream_config.nrepeat_; - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("Start running %d times...\n", nrepeat); } @@ -95,7 +95,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, #if CK_TIME_KERNEL if(stream_config.time_kernel_) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", __func__, @@ -117,7 +117,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, } const int nrepeat = stream_config.nrepeat_; - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("Start running %d times...\n", nrepeat); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 4521b2161..6ab1669e3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -587,7 +587,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle BatchStrideD1s, BatchStrideE1} { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", " << a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index 37ebe2f85..34b1d503a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -658,7 +658,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { { std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 445467be5..e178b8f52 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -719,7 +719,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { arg.Print(); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index 6fd8c0323..0b73317c5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -516,7 +516,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { { std::cout << "arg.a_grid_desc_k0_m_k1_container_{" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index f5c1460f5..13eb23574 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -644,7 +644,7 @@ struct float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 9015f640a..28778d825 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index e815c0784..7fa231d4f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + 
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 760e2840d..3be7313d2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index de4871939..6e6921351 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index 149aca7e3..b84e18130 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1_container_{" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index 439872455..de8f35a64 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index d3af5e63d..b1784b385 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "The group count is not equal to sum of skipped groups " "and kernel args size!" @@ -836,7 +836,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); if(not group_arg_valid) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[" << __func__ << "] group id: " << i << " has invalid GridwiseGemm settings!" << std::endl; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 403bc7fad..36cbd1cd2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -620,7 +620,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop GridwiseGemm::template CheckTensorTransfersValidity( M, N, K))) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << "," << K << "] are not supported by current template parameters!" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 90c0593b2..658f32351 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -514,7 +514,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "The group count is not equal to sum of skipped groups " "and kernel args size!" @@ -545,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK, bhalf_t>::value) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index fdafa9ca5..aea1f5d38 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -1113,7 +1113,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.M % MPerBlock == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ @@ -1130,7 +1130,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.N % NPerBlock == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N value is not a multiple of NPerBlock! 
N: " << karg.N << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ @@ -1149,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 auto K_t = karg.KBatch * KPerBlock; if(!(karg.K % K_t == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " << karg.K << " " << __FILE__ << ":" << __LINE__ @@ -1173,7 +1173,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % ABlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K (" << karg.K << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" @@ -1187,7 +1187,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % ABlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M (" << karg.M << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" @@ -1202,7 +1202,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % BBlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N (" << karg.N << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" @@ -1216,7 +1216,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % BBlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K (" << karg.K << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" @@ -1231,7 +1231,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N (" << karg.N << ") value is not a multiple of " @@ -1247,7 +1247,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M (" << karg.M << ") value is not a multiple of " diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index f2eeaf7e3..6ee279a3f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -446,7 +446,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(!(karg.M % MPerBlock == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ @@ -463,7 +463,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(!(karg.N % NPerBlock == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ @@ -482,7 +482,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto K_t = karg.k_batch * K0PerBlock * K1; if(!(karg.K % K_t == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! 
K: " << karg.K << " " << __FILE__ << ":" << __LINE__ @@ -496,7 +496,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.K % ABlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K (" << karg.K << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" @@ -510,7 +510,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.M % ABlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M (" << karg.M << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" @@ -525,7 +525,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.N % BBlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N (" << karg.N << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" @@ -539,7 +539,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.K % BBlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K (" << karg.K << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" @@ -554,7 +554,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N (" << karg.N << ") value is not a multiple of " @@ -569,7 +569,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M (" << karg.M << ") value is not a multiple of " @@ -584,7 +584,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 const auto num_k_loop = karg.K0Padded / K0PerBlock; if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "The number of k loops (" << num_k_loop << ") value is not supported by GridwiseGemm Pipeline." 
diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 0b6504e52..6455402dc 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -124,7 +124,7 @@ struct EnvVar #define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "") -#define ENV(name) \ +#define CK_ENV(name) \ ck::env::name {} template diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp index 80c1c42b8..09e03de99 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -88,7 +88,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 476ec37eb..0b73e4fcd 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -87,7 +87,7 @@ bool profile_grouped_gemm_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i diff --git a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp index 33e758f40..74faf15be 100644 --- a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp @@ -82,7 +82,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification, Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i diff --git a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp index feb0be87e..14df96d50 100644 --- a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp @@ -88,7 +88,7 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i -- GitLab From 06b891c5c2e75f7a2c4cf71c5597fdc16c236d50 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 20 May 2024 08:34:45 -0700 Subject: [PATCH 26/96] 
aggregate device macros in ck_tile config header (#1297) --- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 3 +- ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 3 +- .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 5 ++-- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 10 +++---- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 10 +++---- ...ise_gemm_xdlops_splitk_lds_direct_load.hpp | 5 ++-- include/ck_tile/core/config.hpp | 30 +++++++++++++------ include/ck_tile/core/numeric/float8.hpp | 20 ++++++------- .../ck_tile/core/tensor/tile_elementwise.hpp | 2 +- .../warp/warp_gemm_attribute_mfma_impl.hpp | 24 +++++++-------- 10 files changed, 56 insertions(+), 56 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 4cc60f283..9d5b74be6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -53,8 +53,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index bf8788a3b..1f60818e3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -45,8 +45,7 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t KBatch = 1; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 67e211ef8..499eb7eb0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -50,8 +50,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, @@ -80,7 +79,7 @@ __global__ void ignore = b_element_op; ignore = c_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx1100__)) +#endif // end of if (defined(__gfx11__)) } // Assume B is Col-Major diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 82c05937c..dc45407e5 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -34,8 +34,7 @@ __global__ void // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); @@ -48,7 +47,7 @@ __global__ void karg); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template ( @@ -49,7 +48,7 @@ __global__ void karg.c_element_op); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template {}(reinterpret_cast(&x), x); -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float max_fp8 = 240.0f; x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); union @@ -500,7 +500,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) { constexpr int seed = 42; uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) union { float fval; @@ -526,7 +526,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float max_fp8 = 240.0f; x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? 
-max_fp8 : x); union @@ -554,7 +554,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) } CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) union { float fval; @@ -598,7 +598,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant) CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float fval; uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); @@ -612,7 +612,7 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float fval; uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); @@ -656,7 +656,7 @@ struct numeric_traits { static constexpr int exp = 4; static constexpr int mant = 3; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 8; #else static constexpr int bias = 7; @@ -668,7 +668,7 @@ struct numeric_traits { static constexpr int exp = 5; static constexpr int mant = 2; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 16; #else static constexpr int bias = 15; // IEEE diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 90ad94b12..48762b722 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -112,7 +112,7 @@ namespace impl { template CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) // This API is designed to use the _pk_ serious of function constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index cb250516f..dd164e72e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -36,8 +36,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; @@ -49,8 +48,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) return bit_cast( __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); #else @@ -89,8 +87,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if 
defined(__gfx9__) c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; @@ -102,8 +99,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); #else @@ -143,7 +139,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { @@ -167,7 +163,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); #elif defined(__gfx908__) @@ -220,7 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { @@ -244,7 +240,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); #elif defined(__gfx908__) @@ -299,7 +295,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); @@ -333,7 +329,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); -- GitLab From 204da9c522cebec5220bba52cd3542ebcaf99e7a Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 21 May 2024 09:52:41 -0500 Subject: [PATCH 27/96] Move grouped conv fwd client examples (#1299) * Move grouped conv fwd client examples * Update 
existing examples * Format --- .../07_grouped_convnd_fwd/CMakeLists.txt | 20 +- .../07_grouped_convnd_fwd/common.hpp | 304 ++++++++++++++++++ .../grouped_conv1d_fwd.cpp | 212 +----------- .../grouped_conv2d_fwd.cpp | 180 +---------- .../grouped_conv3d_fwd_bf8.cpp} | 0 .../grouped_conv3d_fwd_bf8_fp8.cpp} | 0 .../grouped_conv3d_fwd_fp8.cpp} | 0 .../grouped_conv3d_fwd_fp8_bf8.cpp} | 0 client_example/16_convnd_fwd/CMakeLists.txt | 16 - 9 files changed, 345 insertions(+), 387 deletions(-) create mode 100644 client_example/07_grouped_convnd_fwd/common.hpp rename client_example/{16_convnd_fwd/conv3d_fwd_bf8.cpp => 07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp} (100%) rename client_example/{16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp => 07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp} (100%) rename client_example/{16_convnd_fwd/conv3d_fwd_fp8.cpp => 07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp} (100%) rename client_example/{16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp => 07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp} (100%) diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index 710eca9f4..e8c046ff4 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -4,4 +4,22 @@ if(GPU_TARGETS MATCHES "gfx9") add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) -endif() \ No newline at end of file + + if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) + endif() + + if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) + endif() + + if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_fp8_bf8 grouped_conv3d_fwd_fp8_bf8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) + + add_executable(client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) + endif() +endif() diff --git a/client_example/07_grouped_convnd_fwd/common.hpp b/client_example/07_grouped_convnd_fwd/common.hpp new file mode 100644 index 000000000..729af0b88 --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/common.hpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
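+//
+// Shared helpers for the grouped convolution forward client examples:
+//  - SimpleDeviceMem: a minimal hipMalloc/hipFree RAII wrapper,
+//  - GetFlops / GetInputByte / GetWeightByte / GetOutputByte: work and traffic
+//    counters, where flops = 2 * G * N * K * C * prod(output spatial lengths)
+//    * prod(filter spatial lengths),
+//  - run_grouped_conv_fwd: rotates the NDHWGC/GKZYXC/NDHWGK lengths and strides
+//    into the GNC[D]HW order expected by DeviceGroupedConvFwdMultipleABD, times
+//    every registered instance, and re-runs the fastest one.
+//
+// Hedged usage sketch only (the template parameter list below is an assumption
+// for illustration; the examples in this directory supply their own aliases):
+//   run_grouped_conv_fwd<2 /*NumDimSpatial*/, InDataType, WeiDataType, OutDataType,
+//                        InLayout, WeiLayout, OutLayout>(
+//       {N, Hi, Wi, G, C}, {G, K, Y, X, C}, {N, Ho, Wo, G, K});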
+ +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths) +{ + // 2 * G * N * K * C * * + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return static_cast(2) * G * N * K * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd(std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + 
std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t flop = GetFlops(out_lengths, wei_lengths); + std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best 
instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp index 4983ac33c..d3a3111e9 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp @@ -1,17 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include -#include -#include +#include "common.hpp" #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -31,199 +24,16 @@ static constexpr ck::index_t X = 3; static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wo = 28; -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - int main() { - std::array in_lengths{G, N, Wi, C}; - std::array in_strides{0, 0, 0, 1}; - - std::array wei_lengths{G, K, X, C}; - std::array wei_strides{0, 0, 0, 1}; - - std::array out_lengths{G, N, Wo, K}; - std::array out_strides{0, 0, 0, 1}; - - std::partial_sum(rbegin(in_lengths), - std::prev(rend(in_lengths)), - std::next(rbegin(in_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(wei_lengths), - std::prev(rend(wei_lengths)), - std::next(rbegin(wei_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(out_lengths), - std::prev(rend(out_lengths)), - std::next(rbegin(out_strides)), - std::multiplies<>{}); - - // transpose GNWC/GKXC/GNWK to GNCW/GKCX/GNCW - std::rotate(rbegin(in_lengths), - std::next(rbegin(in_lengths)), - std::next(rbegin(in_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(in_strides), - std::next(rbegin(in_strides)), - std::next(rbegin(in_strides), NumDimSpatial + 1)); - std::rotate(rbegin(wei_lengths), - std::next(rbegin(wei_lengths)), - std::next(rbegin(wei_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(wei_strides), - std::next(rbegin(wei_strides)), - std::next(rbegin(wei_strides), NumDimSpatial + 1)); - std::rotate(rbegin(out_lengths), - std::next(rbegin(out_lengths)), - std::next(rbegin(out_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(out_strides), - std::next(rbegin(out_strides)), - std::next(rbegin(out_strides), NumDimSpatial + 1)); - - std::array filter_strides{1}; - std::array filter_dilations{1}; - std::array input_left_pads{1}; - std::array input_right_pads{1}; - - SimpleDeviceMem 
in(sizeof(InDataType) * G * N * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * G * N * Wo * K); - - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_op_name; - int best_op_id = -1; - float best_avg_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - float best_tflops = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::string op_name = op_ptr->GetTypeString(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - - std::size_t flop = std::size_t(2) * G * N * K * C * Wo * X; - std::size_t num_bytes = sizeof(InDataType) * G * N * Wi * C + - sizeof(WeiDataType) * G * K * X * C + - sizeof(OutDataType) * G * N * Wo * K; - - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) - { - best_op_id = i; - best_op_name = op_name; - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - best_tflops = tflops; - } - } - else - { - std::cerr << op_name << " does not support this problem" << std::endl; - } - } - - if(best_op_id < 0) - { - std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; - } - - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - - // run the best intance - { - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } + return run_grouped_conv_fwd({N, Wi, G, C}, {G, K, X, C}, {N, Wo, G, K}) + ? 
EXIT_SUCCESS + : EXIT_FAILURE; } diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp index 938335062..fb8a410ab 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp @@ -1,17 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include -#include -#include +#include "common.hpp" #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -34,167 +27,16 @@ static constexpr ck::index_t Wi = 28; // input W static constexpr ck::index_t Ho = 28; // output H static constexpr ck::index_t Wo = 28; // output W -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - int main() { - // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space - // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW - // Hence, we need to adjust the order of stride - std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; - std::array wei_lengths{G, K, C, Y, X}; - std::array wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; - std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; - - std::array filter_strides{1, 1}; - std::array filter_dilations{1, 1}; - std::array input_left_pads{1, 1}; - std::array input_right_pads{1, 1}; - - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_op_name; - int best_op_id = -1; - float best_avg_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - float best_tflops = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::string op_name = op_ptr->GetTypeString(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, 
true}); - - std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + - sizeof(WeiDataType) * G * K * Y * X * C + - sizeof(OutDataType) * N * Ho * Wo * G * K; - - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) - { - best_op_id = i; - best_op_name = op_name; - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - best_tflops = tflops; - } - } - else - { - std::cerr << op_name << " does not support this problem" << std::endl; - } - } - - if(best_op_id < 0) - { - std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; - } - - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - - // run the best intance - { - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } + return run_grouped_conv_fwd({N, Hi, Wi, G, C}, {G, K, Y, X, C}, {N, Ho, Wo, G, K}) + ? 
EXIT_SUCCESS + : EXIT_FAILURE; } diff --git a/client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index 23311b402..5279e3dfc 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -7,22 +7,6 @@ endif() if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations) - - add_executable(client_conv3d_fwd_fp8 conv3d_fwd_fp8.cpp) - target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) -endif() - -if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) - add_executable(client_conv3d_fwd_bf8 conv3d_fwd_bf8.cpp) - target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) -endif() - -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) - add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp) - target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) - - add_executable(client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp) - target_link_libraries(client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) endif() if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) -- GitLab From 7b027d5643b3e0cf15bd13ea85c4f09a0675f6c1 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 22 May 2024 11:45:27 -0700 Subject: [PATCH 28/96] Select appropriate GPU targets for instances, tests, and examples. 
(#1304) * set individual gpu targets for instances, examples, tests * fix path to hip compiler * fix path to hip compiler once more * aggregate device macros in ck_tile config header * fix the cmake logic for instances * fix clang format * add gfx900 and gfx906 to default set of targets --- CMakeLists.txt | 16 +++--- example/CMakeLists.txt | 35 +++++++++++-- .../gpu/CMakeLists.txt | 50 ++++++++++++++++--- test/CMakeLists.txt | 36 +++++++++++-- 4 files changed, 115 insertions(+), 22 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c23746e7f..3f9e44583 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,7 +23,7 @@ endif() set(version 1.1.0) # Check support for CUDA/HIP in Cmake -project(composable_kernel VERSION ${version} LANGUAGES CXX) +project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) include(CTest) find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED) @@ -112,7 +112,7 @@ message("checking which targets are supported") #Setting GPU_TARGETS on command line will override this list if(NOT PROFILER_ONLY) rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") + TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") else() add_definitions(-DPROFILER_ONLY) set(GPU_TARGETS "" CACHE STRING "" FORCE) @@ -135,12 +135,10 @@ endif() message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}") -set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) - if(GPU_TARGETS) message("Building CK for the following targets: ${GPU_TARGETS}") else() - message("Building CK for the following targets: ${AMDGPU_TARGETS}") + message("Building CK for the default targets: ${DEFAULT_GPU_TARGETS}") endif() if (GPU_TARGETS) @@ -225,7 +223,13 @@ link_libraries(Threads::Threads) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) -message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}") + +## HIP +set(CMAKE_HIP_PLATFORM amd) +set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER}) +set(CMAKE_HIP_EXTENSIONS ON) +message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") ## OpenMP if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 5465adb77..fd9f5cd89 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -44,6 +44,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(EX_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(EX_TARGETS ${GPU_TARGETS}) + endif() + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") @@ -53,23 +60,30 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any XDL examples if gfx9 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") message("removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #only continue if there are some 
source files left on the list if(FILE_NAME) + if(FILE_NAME MATCHES "_xdl") + list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(FILE_NAME MATCHES "_wmma") + list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) add_dependencies(examples ${EXAMPLE_NAME}) add_dependencies(check ${EXAMPLE_NAME}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) @@ -118,6 +132,12 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(EX_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(EX_TARGETS ${GPU_TARGETS}) + endif() #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") @@ -127,23 +147,30 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any XDL examples if gfx9 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") message("removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #only continue if there are some source files left on the list if(FILE_NAME) + if(FILE_NAME MATCHES "_xdl") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(FILE_NAME MATCHES "_wmma") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) add_dependencies(examples ${EXAMPLE_NAME}) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index c035e7e56..05b8c035c 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -36,6 +36,13 @@ function(add_instance_library INSTANCE_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(INST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(INST_TARGETS ${GPU_TARGETS}) + endif() + # Do not build DL instances if DL_KERNELS macro is not set foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") @@ -45,21 +52,40 @@ function(add_instance_library INSTANCE_NAME) endforeach() # Do not build XDL instances if gfx9 targets are not on the target list foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + if(NOT INST_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") message("removing xdl instance ${source} ") 
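# xdl kernels require the MFMA instructions of CDNA (gfx908/gfx90a/gfx94x) GPUs,
# so these sources are dropped when no gfx9 target is requested.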
list(REMOVE_ITEM ARGN "${source}") endif() endforeach() # Do not build WMMA instances if gfx11 targets are not on the target list foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT INST_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") message("removing wmma instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() #only continue if there are some source files left on the list if(ARGN) - add_library(${INSTANCE_NAME} OBJECT ${ARGN}) + set(INST_OBJ) + foreach(source IN LISTS ARGN) + if(INSTANCES_ONLY) + set(INST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(INST_TARGETS ${GPU_TARGETS}) + endif() + if(source MATCHES "_xdl") + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(ARGN MATCHES "_wmma") + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set(offload_targets) + foreach(target IN LISTS INST_TARGETS) + string(APPEND offload_targets "--offload-arch=${target} ") + endforeach() + set_source_files_properties(${source} PROPERTIES COMPILE_FLAGS ${offload_targets}) + list(APPEND INST_OBJ ${source}) + endforeach() + add_library(${INSTANCE_NAME} OBJECT ${INST_OBJ}) target_compile_features(${INSTANCE_NAME} PUBLIC) set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) clang_tidy_check(${INSTANCE_NAME}) @@ -131,6 +157,14 @@ FOREACH(subdir_path ${dir_list}) if(NOT DEFINED DTYPES) set(add_inst 1) endif() + + if(INSTANCES_ONLY) + set(INST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(INST_TARGETS ${GPU_TARGETS}) + endif() + + if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8")) message("quantization instances will not be built!") set(add_inst 0) @@ -139,23 +173,23 @@ FOREACH(subdir_path ${dir_list}) message("Found only dl instances, but DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx9")) + if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9")) message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11")) + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11")) message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT GPU_TARGETS MATCHES "gfx9")) + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT INST_TARGETS MATCHES "gfx9")) message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9")) message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. 
Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) + if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.") set(add_inst 0) endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 25c63ac7f..49b67992b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -40,6 +40,13 @@ function(add_test_executable TEST_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(TEST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(TEST_TARGETS ${GPU_TARGETS}) + endif() + foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl test ${source} ") @@ -47,20 +54,27 @@ function(add_test_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") message("removing xdl test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() #only continue if there are some source files left on the list if(ARGN) + if(ARGN MATCHES "_xdl") + list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(ARGN MATCHES "_wmma") + list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) + set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES ${TEST_TARGETS} ) target_link_libraries(${TEST_NAME} PRIVATE getopt::getopt) add_test(NAME ${TEST_NAME} COMMAND $) add_dependencies(tests ${TEST_NAME}) @@ -105,6 +119,13 @@ function(add_gtest_executable TEST_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(TEST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(TEST_TARGETS ${GPU_TARGETS}) + endif() + foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl test ${source} ") @@ -112,20 +133,27 @@ function(add_gtest_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") message("removing xdl test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() #only continue if there are some source files left on the list if(ARGN) + if(ARGN MATCHES "_xdl") + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(ARGN MATCHES "_wmma") + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) + set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES 
${TEST_TARGETS} ) add_dependencies(tests ${TEST_NAME}) add_dependencies(check ${TEST_NAME}) -- GitLab From fd72380aeb4da1aa87572f9fb801fc6c2bfabf9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 22 May 2024 21:01:01 +0200 Subject: [PATCH 29/96] Optimize grouped conv bwd weight for small M and N (#1303) * Optimize grouped conv bwd weight for small M and N * Fixes --- include/ck/host_utility/flush_cache.hpp | 23 +- .../multi_index_transform.hpp | 15 +- .../multi_index_transform_helper.hpp | 10 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 1314 ++++++++++++---- ...idwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp | 1369 +++++++++++++++++ .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 14 +- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 14 +- .../transform_conv_bwd_weight_to_gemm_v2.hpp | 640 ++++++++ ...conv_bwd_weight_two_stage_xdl_instance.hpp | 22 +- .../grouped_convolution_backward_weight.hpp | 8 +- ...rouped_convolution_backward_weight_xdl.inc | 28 +- .../grouped_conv2d_bwd_weight/CMakeLists.txt | 4 +- ...nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp} | 15 +- ..._nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp | 41 + .../grouped_conv3d_bwd_weight/CMakeLists.txt | 18 +- ...wgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp} | 15 +- ...hwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp | 41 + .../test_grouped_convnd_bwd_weight.cpp | 19 +- 18 files changed, 3223 insertions(+), 387 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp create mode 100644 include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp => device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp} (81%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp => device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp} (81%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index 041428e6a..9d9974d49 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -104,14 +104,19 @@ inline void flush_icache() hip_check_error(hipGetLastError()); } // if TimePrePress == false, return time does not include preprocess's time -template +template float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, - Args& args) + GemmArgs& gemm_args, + Args... 
args) { #if CK_TIME_KERNEL #define MEDIAN 1 @@ -133,7 +138,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, // warm up for(int i = 0; i < stream_config.cold_niters_; ++i) { - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); } @@ -172,7 +177,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, preprocess(); } // run real kernel - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); // end real kernel @@ -190,9 +195,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; - printf("args.p_a_grid: %p, args.p_b_grid:%p\n", - static_cast(args.p_a_grid), - static_cast(args.p_b_grid)); + printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n", + static_cast(gemm_args.p_a_grid), + static_cast(gemm_args.p_b_grid)); } } @@ -216,13 +221,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, else { preprocess(); - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); return 0; } #else - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); return 0; diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index f68473c29..c152cbfb1 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -1952,7 +1952,7 @@ struct Modulo } }; -template +template struct Xor { using LowerIndex = MultiIndex<2>; @@ -1981,8 +1981,15 @@ struct Xor idx_low(Number<0>{}) = idx_up[Number<0>{}]; - idx_low(Number<1>{}) = - idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]); + if constexpr(ApplyModulo) + { + idx_low(Number<1>{}) = + idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]); + } + else + { + idx_low(Number<1>{}) = idx_up[Number<1>{}] ^ idx_up[Number<0>{}]; + } } template {modulus, up_length}; } +template +__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths& low_lengths) +{ + return Xor{low_lengths}; +} + template __host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths) { - return Xor{low_lengths}; + return Xor{low_lengths}; } } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index d30252e68..c1c159101 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -14,95 +14,137 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp" #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" namespace ck { namespace tensor_operation { namespace device { template + index_t NumBatchToMerge, + bool HasMainKBlockLoop, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + index_t MinimumOccupancy = 1, + TailNumber TailNum = TailNumber::Full> __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_batched_gemm_xdlops_bwd_weight( - const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const index_t batch_count, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch 
compute_ptr_offset_of_batch) + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); - - __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; - - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = batch_count; - ignore = block_2_ctile_map; - ignore = compute_ptr_offset_of_batch; - - compute_ptr_offset_of_batch.GetAPtrOffset(0); - compute_ptr_offset_of_batch.GetBPtrOffset(0); - compute_ptr_offset_of_batch.GetCPtrOffset(0); + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = 
__builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -121,7 +163,7 @@ template + BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, + index_t NumBatchToMerge = 1, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle : public DeviceGroupedConvBwdWeight { + static_assert(is_same_v); + static_assert(is_same_v); + static_assert(is_same_v); + using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; using ADataType = OutDataType; @@ -183,101 +232,123 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle static constexpr auto K1Number = Number{}; - static constexpr auto conv_to_gemm_transformer = + static constexpr auto conv_to_gemm_transformer_v2 = + TransformConvBwdWeightToGemmV2{}; + + static constexpr auto conv_to_gemm_transformer_v1 = TransformConvBwdWeightToGemm{}; - // Bytes per 32 lds bank: 32 * 4 bytes - static constexpr auto BankLength = 128; - static constexpr auto ElePerBank = BankLength / sizeof(ADataType); - - // M1 & M0 - static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; - static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; - static constexpr auto ABlockLdsM1Padding = 4; + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; - // N1 & N0 - static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; - static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; - static constexpr auto BBlockLdsN1Padding = 4; + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } - template ::type = false> + template ::type = false> static auto GetABCGridDesc() { const ck::index_t dim = 1; const ck::index_t batch = 1; - const std::array lengths{1}; - const std::array strides{1, 1, 1, 1}; - const std::array params{1}; - return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( - dim, - dim, - dim, - lengths, - lengths, - lengths, - strides, - strides, - strides, - params, - params, - params, - params, - batch); + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return 
conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); } template ::type = false> - static auto GetABCGridDesc() + static auto GetElementwiseCGridDesc() { const ck::index_t dim = 1; const ck::index_t batch = 1; const std::array lengths{1, 1}; const std::array strides{1, 1, 1, 1, 1}; const std::array params{1, 1}; - return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( - dim, - dim, - dim, - lengths, - lengths, - lengths, - strides, - strides, - strides, - params, - params, - params, - params, - batch); + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; } template ::type = false> - static auto GetABCGridDesc() + static auto GetElementwiseCGridDesc() { const ck::index_t dim = 1; const ck::index_t batch = 1; const std::array lengths{1, 1, 1}; const std::array strides{1, 1, 1, 1, 1, 1}; const std::array params{1, 1, 1}; - return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>( - dim, - dim, - dim, - lengths, - lengths, - lengths, - strides, - strides, - strides, - params, - params, - params, - params, - batch); + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; } using ABCGridDescs = decltype(GetABCGridDesc()); @@ -285,60 +356,56 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle using AGridDesc_K0_M_K1 = remove_cvref_t; using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - - using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< - BlockSize, - ADataType, - BDataType, - AccDataType, - AccDataType, - InMemoryDataOperationEnum::AtomicAdd, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - element_wise::PassThrough, - MPerBlock, - NPerBlock, - K0PerBlock, - MPerXdl, - NPerXdl, - K1, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - ABlockLdsAddExtraM, - ABlockLdsM1PerBlock, - ABlockLdsM0PerBlock, - ABlockLdsM1Padding, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, - BBlockLdsAddExtraN, - BBlockLdsN1PerBlock, - BBlockLdsN0PerBlock, - BBlockLdsN1Padding, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CBlockTransferScalarPerVector_NWaveNPerXdl, - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - true, - true, - 1, - PipelineVersion::v1, - ComputeTypeA, - ComputeTypeB>; + using CElementwiseGridDesc_M_N = + remove_cvref_t())>; + + using GridwiseGemm = + GridwiseGemm_xdl_cshuffle_v3; static constexpr index_t ClusterLengthMPerBlock = 
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); @@ -347,8 +414,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; using GridwiseElementwise = - GridwiseElementwise, - Tuple, + GridwiseElementwise, + Tuple, Tuple, Tuple, Block2TileMapElementwise, @@ -366,10 +433,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); - - using Block2CTileMap = - decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{}, 1, 1)); struct Argument : public BaseArgument { @@ -395,11 +460,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle : p_a_grid_{p_out_grid}, p_b_grid_{p_in_grid}, p_e_grid_{p_wei_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, ce_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{}, compute_ptr_offset_of_batch_{}, M01_{M01}, N01_{N01}, @@ -430,7 +494,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle begin(output_spatial_lengths_)); const auto descs = - conv_to_gemm_transformer + conv_to_gemm_transformer_v2 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( Conv_N_, Conv_K_, @@ -447,15 +511,34 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle input_right_pads, k_batch_); - a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; - b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; - ce_grid_desc_m_n_ = descs[I2]; + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + ce_elementwise_grid_desc_m_n_ = + conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_)[I2]; - block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_); elementwise_block_2_ctile_map_ = Block2TileMapElementwise{ ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); + // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; @@ -465,16 +548,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); - - if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, - b_grid_desc_kbatch_k0_n_k1_, - ce_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock( - ce_grid_desc_m_n_); - } + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ce_grid_desc_m_n_, + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); } std::size_t GetWorkspaceSizeBytes() const @@ -486,12 +564,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const BDataType* 
p_b_grid_; EDataType* p_e_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N ce_grid_desc_m_n_; + CElementwiseGridDesc_M_N ce_elementwise_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; - Block2CTileMap block_2_ctile_map_; Block2TileMapElementwise elementwise_block_2_ctile_map_; // for computing batch offset @@ -525,96 +603,676 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle void ShowInfo(const Argument& arg) { - std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", " << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.ce_grid_desc_m_n_, - arg.block_2_ctile_map_)) + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + AccDataType* p_c_grid = type_convert(arg.p_workspace_); + + // nullptr for output, will be set after workspace set + typename GridwiseGemm::Argument gemm_arg{arg.p_a_grid_, + arg.p_b_grid_, + p_c_grid, + GemmM, + GemmN, + GemmK, + I0, + I0, + I0, + arg.k_batch_}; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumBatchToMerge); + + float ave_time = 0; + + index_t k_grain = gemm_arg.KBatch * KPerBlock; + index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * (KPerBlock); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto num_k_per_block = + arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + + const auto clear_workspace = [&]() { + hip_check_error(hipMemsetAsync( + gemm_arg.p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); + }; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper rotating_mem( + gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * 
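// Illustrative sketch, not part of the patch: the two-stage flow driven by
// RunGemmV3 above. Stage one zeroes an AccDataType workspace and lets the
// GEMM accumulate into it (atomically when KBatch > 1); stage two is the
// elementwise kernel that converts the workspace into the final weight
// tensor. The snippet only reproduces the K-split bookkeeping, with assumed
// example values for KPerBlock and AK1.
#include <cstdio>

int main()
{
    const int K = 4096, KBatch = 4, KPerBlock = 32, AK1 = 8;

    // Round K up so every K-split batch covers a whole number of KPerBlock tiles.
    const int k_grain  = KBatch * KPerBlock;
    const int K_split  = (K + k_grain - 1) / k_grain * KPerBlock; // K handled by one split
    const int num_loop = K_split / KPerBlock;                     // main-loop iterations
    const int num_k_per_block = K_split / AK1;                    // per-split K0 count (assumed)

    std::printf("K_split=%d loops=%d num_k_per_block=%d\n", K_split, num_loop, num_k_per_block);
    return 0;
}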
gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + clear_workspace(); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + else + { + ave_time = launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + 
Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + 
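// Illustrative sketch, not part of the patch: the dispatch pattern used by the
// TailNumber chain above, reduced to plain C++. A runtime tail count selects
// one compile-time kernel specialization, and "if constexpr" keeps
// specializations beyond the pipeline's PrefetchStages from being instantiated.
#include <cstdio>

enum class TailNumber { One = 1, Two, Three, Four, Five, Six, Seven, Full };

template <TailNumber Tail>
void run_kernel() { std::printf("launch kernel with TailNumber=%d\n", static_cast<int>(Tail)); }

template <int PrefetchStages>
void dispatch(TailNumber tail)
{
    if(tail == TailNumber::One)  { run_kernel<TailNumber::One>(); }
    if(tail == TailNumber::Full) { run_kernel<TailNumber::Full>(); }
    if constexpr(PrefetchStages > 2)
    {
        if(tail == TailNumber::Two) { run_kernel<TailNumber::Two>(); }
    }
    if constexpr(PrefetchStages > 3)
    {
        if(tail == TailNumber::Three) { run_kernel<TailNumber::Three>(); }
    }
    // ... continuing up to TailNumber::Seven, mirroring the chain above.
}

int main()
{
    dispatch<3>(TailNumber::Two); // only One/Full/Two branches exist for 3 prefetch stages
    return 0;
}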
TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + 
remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else { - throw std::runtime_error( - "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } } - const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - auto launch_gemm_kernel = [&](auto has_main_k_block_loop) { - AccDataType* p_c_grid = type_convert(arg.p_workspace_); - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; - - constexpr bool has_main_loop = has_main_k_block_loop.value; - - auto preprocess = [&]() { - hip_check_error(hipMemsetAsync( - p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); - }; - - const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< - GridwiseGemm, - ADataType, - BDataType, - AccDataType, - OutElementwiseOperation, - InElementwiseOperation, - element_wise::PassThrough, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - ComputePtrOffsetOfStridedBatch, - has_main_loop>; - - return launch_and_time_kernel_with_preprocess( - stream_config, - preprocess, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - p_c_grid, - arg.a_element_op_, - arg.b_element_op_, - element_wise::PassThrough{}, - arg.Conv_G_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_, - arg.compute_ptr_offset_of_batch_); - }; + return ave_time; + } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { auto launch_elementwise_kernel = [&]() { const AccDataType* p_c_grid = type_convert(arg.p_workspace_); - const index_t grid_size = - arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * - arg.Conv_G_; + const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize( + arg.ce_elementwise_grid_desc_m_n_) * + arg.Conv_G_; std::array in_out_batch_strides = { arg.compute_ptr_offset_of_batch_.BatchStrideC_}; const auto kernel = kernel_batched_elementwise, - ck::Tuple, + ck::Tuple, + ck::Tuple, ck::Tuple, ck::Tuple, Block2TileMapElementwise, @@ -627,8 +1285,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle dim3(grid_size), dim3(BlockSize), 0, - make_tuple(arg.ce_grid_desc_m_n_), - make_tuple(arg.ce_grid_desc_m_n_), + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(arg.ce_elementwise_grid_desc_m_n_), make_tuple(p_c_grid), make_tuple(arg.p_e_grid_), arg.elementwise_block_2_ctile_map_, @@ -638,16 +1296,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle in_out_batch_strides); }; - float avg_time = 0; - if(has_main_k0_block_loop) - { - avg_time = 
launch_gemm_kernel(integral_constant{}); - } - else - { - avg_time = launch_gemm_kernel(integral_constant{}); - } - + float avg_time = RunGemmV3(arg, stream_config); avg_time += launch_elementwise_kernel(); return avg_time; } @@ -667,6 +1316,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + // Check this here, it allows to use other instances from factory even // if workspace is not allocated if(!arg.p_workspace_) @@ -723,10 +1389,38 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } } + if constexpr(NumBatchToMerge > 1) + { + // support only if whole M and N can be proccessed on one block + if(!(GemmM <= MPerBlock && GemmN <= NPerBlock)) + { + return false; + } + if(!(arg.Conv_C_ == 1 && arg.Conv_K_ == 1)) + { + return false; + } + if(arg.Conv_G_ % NumBatchToMerge != 0) + { + return false; + } + } + + if(!(arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0)) + { + if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1)) + { + return false; + } + if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1)) + { + return false; + } + } + // vector load A/B matrix from global memory - if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && - arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && - arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1)) { return false; } @@ -737,11 +1431,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle return false; } - // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.ce_grid_desc_m_n_, - arg.block_2_ctile_map_); + return true; } bool IsSupportedArgument(const BaseArgument* p_arg) override @@ -840,13 +1530,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { auto str = std::stringstream(); + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + // clang-format off str << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle" << "<" << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock << ", " + << KPerBlock << ", " << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " << K1 << ", " << MXdlPerWave << ", " @@ -857,7 +1558,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle << BBlockTransferDstScalarPerVector_K1 << ", " << CShuffleMXdlPerWavePerShuffle << ", " << CShuffleNXdlPerWavePerShuffle << ", " - << 
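// Illustrative sketch, not part of the patch: the batch-merge eligibility test
// from IsSupportedArgument above, written as a stand-alone predicate with
// assumed parameter names.
#include <cstdio>

bool can_merge_batches(int GemmM, int GemmN, int C, int K, int G,
                       int MPerBlock, int NPerBlock, int NumBatchToMerge)
{
    if(NumBatchToMerge <= 1) return true;                           // nothing to merge
    if(!(GemmM <= MPerBlock && GemmN <= NPerBlock)) return false;   // one tile per group
    if(!(C == 1 && K == 1)) return false;                           // single input/output channel
    return G % NumBatchToMerge == 0;                                // groups split evenly
}

int main()
{
    std::printf("%d\n", can_merge_batches(64, 64, 1, 1, 8, 128, 128, 2));  // 1: fits one tile
    std::printf("%d\n", can_merge_batches(256, 64, 1, 1, 8, 128, 128, 2)); // 0: GemmM too large
    return 0;
}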
CBlockTransferScalarPerVector_NWaveNPerXdl + << CBlockTransferScalarPerVector_NWaveNPerXdl << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << NumBatchToMerge << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp new file mode 100644 index 000000000..d2a06ba9a --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp @@ -0,0 +1,1369 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * 
KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr 
usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 
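// Illustrative sketch, not part of the patch: the arithmetic behind the
// MLdsLayer / NLdsLayer folding used by the LDS descriptors here. One row of
// KPerBlock elements should span at least the 32 banks * 4 bytes of an LDS
// bank cycle; if it does not, several M (or N) rows are folded into one layer
// and the xor swizzle spreads their accesses across banks. Values are examples.
#include <cstdio>

int lds_layer(int KPerBlock, int elem_bytes)
{
    const int bank_bytes = 32 * 4;                        // 32 banks, 4 bytes each
    const int layer      = bank_bytes / (KPerBlock * elem_bytes);
    return layer < 1 ? 1 : layer;                         // at least one row per layer
}

int main()
{
    std::printf("fp16, KPerBlock=32 -> LdsLayer=%d\n", lds_layer(32, 2)); // 2 rows per layer
    std::printf("fp32, KPerBlock=64 -> LdsLayer=%d\n", lds_layer(64, 4)); // 1 row per layer
    return 0;
}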
1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + 
c_block_size * sizeof(CShuffleDataType)); + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id = 0) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( + make_multi_index(static_cast(blockIdx.x))); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + 
decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (KPerBlock * problem.KBatch)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
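// Illustrative sketch, not part of the patch: how MPerBlock decomposes into
// the M0..M4 indices used by the C-shuffle descriptors below, assuming the
// usual relation MPerBlock = MXdlPerWave * MWave * MPerXdl (and likewise for N).
#include <cstdio>

int main()
{
    const int MPerBlock = 128, MXdlPerWave = 2, MPerXdl = 32;
    const int MWave = MPerBlock / (MXdlPerWave * MPerXdl); // waves along M in one block
    // M0 counts xdlops per wave, M1 = MWave, and M2 * M3 * M4 together span
    // MPerXdl, mirroring the unmerge in c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.
    std::printf("MWave=%d, M covered per block=%d\n", MWave, MXdlPerWave * MWave * MPerXdl);
    return 0;
}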
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id = 0) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const 
CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( + make_multi_index(static_cast(blockIdx.x))); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * 
sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (KPerBlock * problem.KBatch)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto 
c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + 
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index dc45407e5..50e6f68e6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -603,8 +603,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -669,7 +669,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(AK1Number)), @@ -740,8 +740,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -803,7 +803,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(BK1Number)), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index 469ed1c8b..f9071bd29 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
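+// Note: the hunks below apply the same make_xor_transform ->
+// make_xor_with_modulo_transform substitution to this header's A/B LDS
+// descriptors as the gridwise_gemm_xdl_cshuffle_v3.hpp hunks above. As a
+// sketch of the intent (assuming the transform swizzles an index pair
+// (i0, i1) to (i0, (i1 ^ i0) % len1)): the xor spreads rows across LDS banks
+// to avoid conflicts, and the modulo keeps the permuted index in range even
+// when len1 is not a power of two.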
#pragma once @@ -781,8 +781,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -847,7 +847,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(AK1Number)), @@ -918,8 +918,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -981,7 +981,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(BK1Number)), diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp new file mode 100644 index 000000000..158890d7a --- /dev/null +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -0,0 +1,640 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +/** + * @brief Transform conv bwd weight to gemm v2 + * + * This version does following things: + * 1. Merge KBatch with K0 to align descriptor with universal gemm + * 2. Merge Batch with M and N dimension. It allows to increase compute in + * case of small M and N. It also allows to vector load and store in case of + * K = 1, C = 1 and NHWGC layout. 
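+ *
+ * Illustrative sizes (hypothetical, 2D NHWGC case): with K = 16, C = 4,
+ * Y = X = 3, N = 8, Ho = Wo = 14 and NumBatchToMerge = 4, the merged GEMM is
+ *   GemmM      = K * NumBatchToMerge          = 64,
+ *   GemmN      = Y * X * NumBatchToMerge * C  = 144,
+ *   GemmKTotal = N * Ho * Wo                  = 1568,
+ * i.e. a 64 x 144 output tile instead of 16 x 36 per merged batch, which
+ * gives the compute units more work when K and C are small.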
+ */ +template +struct TransformConvBwdWeightToGemmV2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[4]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumBatchToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t HiStride = input_strides[3]; + const index_t WiStride = input_strides[4]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumBatchToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, NumBatchToMerge, C), + make_tuple(NStride, HiStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[4]; + const auto BatchStride = weights_strides[0]; + // Add NumBatchToMerge for Batch+M dimension and, 1 as a placehorder + // for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumBatchToMerge, K, Y * X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumBatchToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(K), + make_pass_through_transform(Y * X), + make_pad_transform(1, 0, NumBatchToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. 
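+ // Worked example: with NumBatchToMerge = 4 the placeholder dimension is
+ // padded from 1 to 4, so the descriptor addresses pairs (b_m, b_n) of
+ // merged-batch indices coming from the M and N side. After the xor
+ // transform a pair reaches the single real element only when
+ // b_m ^ b_n == 0, i.e. on the diagonal b_m == b_n; a pair such as (2, 3)
+ // xors to 1 and resolves into the padded region, so its contribution is
+ // dropped. Since b_m ^ b_n < 4 whenever both indices are below the
+ // power of two 4, no extra modulo is needed.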
+ static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 || + NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 || + NumBatchToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(Y * X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)), + make_merge_transform(make_tuple(Y * X, NumBatchToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumBatchToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumBatchToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumBatchToMerge, C), + make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[5]; + const auto BatchStride = weights_strides[0]; + // Add NumBatchToMerge for Batch+M dimension and, 1 for placehord for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumBatchToMerge, K, Z * Y * X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumBatchToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(K), + make_pass_through_transform(Z * Y * X), + make_pad_transform(1, 0, NumBatchToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. 
So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 || + NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 || + NumBatchToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(Z * Y * X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)), + make_merge_transform(make_tuple(Z * Y * X, NumBatchToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K * NumBatchToMerge; + const index_t GemmN = C * X * Y * NumBatchToMerge; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_tuple(Sequence<0>{}, 
Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5>{}, + Sequence<6>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, NumBatchToMerge, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, 
GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K * NumBatchToMerge; + 
const index_t GemmN = C * Z * X * Y * NumBatchToMerge; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + 
Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{}, + Sequence<8>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumBatchToMerge, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7, 8>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } // function end +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index 8120eff25..77d372843 100644 --- 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -35,14 +35,24 @@ template + ConvolutionBackwardWeightSpecialization ConvSpec, + BlockGemmPipelineScheduler Scheduler, + BlockGemmPipelineVersion PipelineVersion> using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple< // clang-format off - //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 8, 1, 8>, 1> + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumBatch| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + 
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 91b7df3d4..5a703e581 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -352,7 +352,9 @@ struct DeviceOperationInstanceFactory>>& instances); -void 
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances( +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances( std::vector{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< - 2, - NHWGC, - GKYXC, - NHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp new file mode 100644 index 000000000..398c14b11 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances( + std::vector>>& instances) +{ + // 1. 
Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 435d1831e..8e939c15a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,12 +1,14 @@ -# XDL_DL_WMMA_KERNELS + # XDL_DL_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp) + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp + ) if(DL_KERNELS) list(APPEND GROUPED_CONV3D_BWD_WEIGHT diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp similarity index 81% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp index c4849c017..4d0f1e68c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp @@ -10,7 +10,7 @@ namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( +void 
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances( std::vector{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp new file mode 100644 index 000000000..c5cc062f2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 1c8082645..5ef073066 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -32,19 +32,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test std::vector conv_params; std::vector split_ks{1, 2}; - bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k) + bool skip_case(const ck::index_t split_k) { - // Odd K or C values are supported only by DL and WMMA - // kernels (only applies to fp16) - // DL and WMMA kernels currently support only `split_k=1` - if constexpr(std::is_same_v) - { - if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0)) - { - return true; - } - } - // 1d NWGC is only supported by DL kernel // DL kernel is only supported for split_k=1 if constexpr(std::is_same_v && std::is_same_v) @@ -100,7 +89,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test { for(auto& param : conv_params) { - if(!skip_case(param, split_k)) + if(!skip_case(split_k)) { pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_implconv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, 
{32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}); this->Run(); } @@ -207,5 +198,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->Run(); } -- GitLab From 29e58d5b28a7f8490ced9b25c17519d110f7bba7 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 21 May 2024 20:37:26 +0000 Subject: [PATCH 30/96] Make the library which generates CK instances for pytorch2 inductor's CK backend usage Also bundle the CK library and include files with the pip package. The package is pip-installable with `pip install git+https://github.com/tenpercent/composable_kernel@enable-pip` (substitute the repo path and branch if necessary) Testing: `myenv/bin/python3 -m ck4inductor.universal_gemm.gen_instances` (prints a list of instances) `tree myenv/lib/python3.12/site-packages/ck4inductor` (observe the list of sources along the installed package) --- pyproject.toml | 36 ++ python/ck4inductor/__init__.py | 0 .../universal_gemm/gen_instances.py | 570 ++++++++++++++++++ python/ck4inductor/universal_gemm/op.py | 95 +++ python/ck4inductor/util.py | 7 + 5 files changed, 708 insertions(+) create mode 100644 pyproject.toml create mode 100644 python/ck4inductor/__init__.py create mode 100644 python/ck4inductor/universal_gemm/gen_instances.py create mode 100644 python/ck4inductor/universal_gemm/op.py create mode 100644 python/ck4inductor/util.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..8e7e8607b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "rocm-composable-kernel" +dynamic = ["version"] +description = "Composable Kernel, performance-critical kernels for machine learning workloads" +readme = "README.md" +requires-python = ">=3.8" +license = {file = "LICENSE"} +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] +dependencies = [] + +[project.urls] +"Homepage" = "https://github.com/rocm/composable_kernel" +"Bug Tracker" = "https://github.com/rocm/composable_kernel/issues" + +[tool.setuptools] +packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library"] + +[tool.setuptools.package-dir] +ck4inductor = "python/ck4inductor" +"ck4inductor.include" = "include" +"ck4inductor.library" = "library" + +[tool.setuptools.package-data] +"ck4inductor.include" = ["ck/**/*.hpp"] +"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"] + +[tool.setuptools.dynamic] +version = { attr = "setuptools_scm.get_version" } diff --git a/python/ck4inductor/__init__.py b/python/ck4inductor/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/ck4inductor/universal_gemm/gen_instances.py b/python/ck4inductor/universal_gemm/gen_instances.py new file mode 100644 index 000000000..8b6d6b73b --- /dev/null +++ b/python/ck4inductor/universal_gemm/gen_instances.py @@ -0,0 +1,570 @@ +import logging +import os +import subprocess +from dataclasses import fields, replace 
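For reference while reading the new module: a minimal consumption sketch follows (hypothetical caller code, not part of this patch). It assumes the rocm-composable-kernel package declared above has been pip-installed as described in the commit message, so that the generator functions added later in this file and the CKGemmOperation helper in op.py are importable.

# hypothetical caller, not part of the patch
from ck4inductor.universal_gemm.gen_instances import (
    gen_ops_library,      # parses instances found in the bundled CK library sources
    gen_ops_preselected,  # hand-tuned f16 row/col/row instances
)

ops = gen_ops_library() or gen_ops_preselected()   # fall back if the library dir is absent
print("found", len(ops), "CK universal GEMM instances")
for op in ops[:3]:
    print(op.name())   # C++-safe alias generated for each template instance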
+from functools import lru_cache, partial +from typing import List + +from ..util import library_path + +from .op import CKGemmOperation + +log = logging.getLogger(__name__) + + +def _ck_library_dir(): + gemm_instances_path = os.path.join( + library_path(), "src", "tensor_operation_instance", "gpu", "gemm_universal" + ) + if not os.path.exists(gemm_instances_path): + log.error("CK library path %s does not exist", gemm_instances_path) + return None + return gemm_instances_path + + +def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]: + """ + Parse the lines containing Universal Gemm template instances into `CKGemmOperation` instances + """ + + def maybe_int(s): + try: + return int(s) + except ValueError: + return s + + op_instances = [] + for line in str_instances: + s_template_args = line.split("DeviceGemm_Xdl_CShuffleV3")[-1].strip("<>, ") + template_args = [] + i_current = 0 + while i_current < len(s_template_args): + if s_template_args[i_current] == " ": + # skip whitespace + i_current += 1 + continue + elif s_template_args[i_current : i_current + 2] == "S<": + # parse template S + i_next = s_template_args.find(">", i_current) + template_args.append( + tuple(map(int, s_template_args[i_current + 2 : i_next].split(","))) + ) + i_current = i_next + 2 + else: + # all string attributes must be either type aliases or global constants in C++ + i_next = s_template_args.find(",", i_current) + template_args.append( + maybe_int( + s_template_args[i_current : i_next if i_next != -1 else None] + ) + ) + if i_next != -1: + i_current = i_next + 1 + if i_next == -1: + break + # pad with `None`s for the fields which are not defined in the instance + new_instance = CKGemmOperation( + *template_args, # type: ignore[arg-type] + *((None,) * (len(fields(CKGemmOperation)) - len(template_args))), + ) + # the last 2 template parameters are optional + # if they are absent, substitute them with default values from Universal Gemm C++ template declaration + if new_instance.a_compute_dtype is None: + new_instance.a_compute_dtype = new_instance.c_element_dtype + if new_instance.b_compute_dtype is None: + new_instance.b_compute_dtype = new_instance.c_element_dtype + + op_instances.append(new_instance) + return op_instances + + +def default_instances() -> List[CKGemmOperation]: + # fallback: known working op instance for problem size M=2240 K=256 N=2048 + # all string attributes must be either type aliases or global constants in C++ + + return [ + CKGemmOperation( + a_layout="Row", + b_layout="Row", + c_layout="Row", + a_element_dtype="F16", + b_element_dtype="F16", + c_element_dtype="F16", + a_compute_dtype="F16", + b_compute_dtype="F16", + acc_dtype="F32", + c_shuffle_dtype="F16", + a_elementwise_op="PassThrough", + b_elementwise_op="PassThrough", + c_elementwise_op="PassThrough", + gemm_specialization="GemmSpecialization::Default", + block_size=256, + m_per_block=224, + n_per_block=256, + k_per_block=64, + a_k1=8, + b_k1=2, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=7, + n_xdl_per_wave=8, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1), + a_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + a_block_transfer_src_access_order=(1, 0, 2), + a_block_transfer_src_vector_dim=2, + a_block_transfer_src_scalar_per_vector=8, + a_block_transfer_dst_scalar_per_vector_ak1=8, + a_block_lds_extra_m=0, # type: ignore[arg-type] + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1), + b_block_transfer_thread_cluster_arrange_order=(0, 2, 1), + b_block_transfer_src_access_order=(0, 
2, 1), + b_block_transfer_src_vector_dim=1, + b_block_transfer_src_scalar_per_vector=8, + b_block_transfer_dst_scalar_per_vector_bk1=2, + b_block_lds_extra_n=0, # type: ignore[arg-type] + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ) + ] + + +@lru_cache(None) +def gen_ops_library() -> List[CKGemmOperation]: + """ + Parse the Universal Gemm instances defined in the composable kernel library folder. + """ + ck_library_dir = _ck_library_dir() + if not ck_library_dir: + return [] + + grep_result = subprocess.run( + [ + "grep", + "-inR", + "DeviceGemm_Xdl_CShuffleV3", + _ck_library_dir(), + ], + capture_output=True, + text=True, + ) + + op_instances = parse_instances(grep_result.stdout.strip().split("\n")) + + log.debug("ck instances from library: %d", len(op_instances)) + + schedulers = [ + "BlockGemmPipelineScheduler::Intrawave", + "BlockGemmPipelineScheduler::Interwave", + ] + gemm_specs = [ + "GemmSpecialization::Default", + "GemmSpecialization::MPadding", + "GemmSpecialization::NPadding", + "GemmSpecialization::KPadding", + "GemmSpecialization::MNPadding", + "GemmSpecialization::MKPadding", + "GemmSpecialization::NKPadding", + "GemmSpecialization::MNKPadding", + ] + + # substitute templated args by looping through their domains + substitute_instances = [] + for instance in op_instances: + sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched" + sub_spec = instance.gemm_specialization == "GemmSpec" + schedulers_range = ( + schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler] + ) + spec_range = gemm_specs if sub_spec else [instance.gemm_specialization] + for scheduler in schedulers_range: + for spec in spec_range: + substitute_instances.append( + replace( + instance, + block_gemm_pipeline_scheduler=scheduler, + gemm_specialization=spec, + ) + ) + + return substitute_instances + + +@lru_cache(None) +def gen_ops_preselected() -> List[CKGemmOperation]: + """ + Manually selected (through benchmarking) F16/F16/F16 Row/Col/Row instances + """ + ck_gemm_f16_rcr = partial( + CKGemmOperation, + a_layout="Row", + b_layout="Col", + c_layout="Row", + a_element_dtype="F16", + b_element_dtype="F16", + c_element_dtype="F16", + acc_dtype="F32", + c_shuffle_dtype="F16", + a_elementwise_op="PassThrough", + b_elementwise_op="PassThrough", + c_elementwise_op="PassThrough", + k_per_block=64, + a_k1=8, + b_k1=8, + a_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + a_block_transfer_src_access_order=(1, 0, 2), + a_block_transfer_src_vector_dim=2, + a_block_transfer_src_scalar_per_vector=8, + a_block_transfer_dst_scalar_per_vector_ak1=8, + a_block_lds_extra_m=0, + b_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + b_block_transfer_src_access_order=(1, 0, 2), + b_block_transfer_src_vector_dim=2, + b_block_transfer_src_scalar_per_vector=8, + b_block_transfer_dst_scalar_per_vector_bk1=8, + b_block_lds_extra_n=0, + a_compute_dtype="F16", + b_compute_dtype="F16", + ) + ck_gemm_f16_rcr_compute_friendly = partial( + ck_gemm_f16_rcr, + block_size=256, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1), + 
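The scheduler/specialization expansion performed in gen_ops_library() above reduces to a small dataclasses.replace() pattern; here is a standalone toy illustrating that idea (illustrative only, with field names shortened, not part of the patch):

from dataclasses import dataclass, replace
from itertools import product

@dataclass
class Toy:
    scheduler: str
    spec: str

# an instance whose fields are still C++ template placeholders
templated = Toy(scheduler="BlkGemmPipeSched", spec="GemmSpec")
schedulers = ["Intrawave", "Interwave"]
specs = ["Default", "MNKPadding"]
expanded = [replace(templated, scheduler=s, spec=p) for s, p in product(schedulers, specs)]
print(len(expanded))  # 4 concrete variants produced from one templated instance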
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ) + ck_gemm_f16_rcr_memory_friendly = partial( + ck_gemm_f16_rcr, + block_size=128, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1), + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Interwave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v2", + ) + ck_gemm_f16_rcr_latency_friendly = partial( + ck_gemm_f16_rcr, + gemm_specialization="GemmSpecialization::Default", + block_size=128, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1), + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v1", + ) + return [ + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=224, + n_per_block=256, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=7, + n_xdl_per_wave=8, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v4", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v5", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + 
c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v4", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v5", + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=16, + n_per_block=32, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=16, + n_per_block=32, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=16, + n_per_block=64, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=64, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=32, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + 
c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=64, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=2, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=2, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 64, + 1, + 2, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=64, + n_per_block=32, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=32, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=2, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_latency_friendly( + m_per_block=16, + n_per_block=32, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + ), + ck_gemm_f16_rcr_latency_friendly( + m_per_block=32, + n_per_block=16, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + ), + ] + + +if __name__ == "__main__": + print(gen_ops_library()) diff --git a/python/ck4inductor/universal_gemm/op.py b/python/ck4inductor/universal_gemm/op.py new file mode 100644 index 000000000..ab541c5fb --- /dev/null +++ b/python/ck4inductor/universal_gemm/op.py @@ -0,0 +1,95 @@ +from dataclasses import asdict, dataclass +from typing import Optional, Tuple + + +@dataclass +class CKGemmOperation: + """ + A python dataclass storing the template parameters of a CK Universal Gemm template instance + """ + + a_layout: str + b_layout: str + c_layout: str + + a_element_dtype: str + b_element_dtype: str + c_element_dtype: str + + acc_dtype: str + c_shuffle_dtype: str + + a_elementwise_op: str + b_elementwise_op: str + c_elementwise_op: str + + gemm_specialization: str + + block_size: int + + m_per_block: int + n_per_block: int + k_per_block: int + + a_k1: int + b_k1: int + + m_per_xdl: int + n_per_xdl: int + + m_xdl_per_wave: int + n_xdl_per_wave: int + + a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int] + a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] + a_block_transfer_src_access_order: Tuple[int, int, int] + a_block_transfer_src_vector_dim: int + a_block_transfer_src_scalar_per_vector: int + a_block_transfer_dst_scalar_per_vector_ak1: int + a_block_lds_extra_m: bool + + b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int] + b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] + b_block_transfer_src_access_order: Tuple[int, int, int] + + 
b_block_transfer_src_vector_dim: int + b_block_transfer_src_scalar_per_vector: int + b_block_transfer_dst_scalar_per_vector_bk1: int + b_block_lds_extra_n: bool + + c_shuffle_m_xdl_per_wave_per_shuffle: int + c_shuffle_n_xdl_per_wave_per_shuffle: int + + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: ( + Tuple[int, int, int, int] + ) + c_shuffle_block_transfer_scalar_per_vector_n_per_block: int + + block_gemm_pipeline_scheduler: str + block_gemm_pipeline_version: Optional[str] + + a_compute_dtype: Optional[str] + b_compute_dtype: Optional[str] + + def name(self): + # cpp alias for template instance + return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}" + + def key_name(self): + # TBD; must be unique per instance. Intended to use as dict key + return "_".join( + [ + "K" + + field_name.replace("_", "").lower() + + "V" + + ( + "x".join(map(str, iter(field_value))) + if isinstance(field_value, tuple) + else str(field_value).replace(":", "") + ) + for field_name, field_value in self.dict_items() + ] + ) + + def dict_items(self): + return asdict(self).items() diff --git a/python/ck4inductor/util.py b/python/ck4inductor/util.py new file mode 100644 index 000000000..79d6be00f --- /dev/null +++ b/python/ck4inductor/util.py @@ -0,0 +1,7 @@ +import functools +import os + + +@functools.lru_cache(None) +def library_path(): + return os.path.join(os.path.dirname(__file__), 'library') -- GitLab From 06a9b72caf3be81b3177c036a2a7c46fb24cce2a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 23 May 2024 07:45:53 -0700 Subject: [PATCH 31/96] Bump rocm-docs-core from 1.1.2 to 1.1.3 in /docs/sphinx (#1308) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 1.1.2 to 1.1.3. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v1.1.2...v1.1.3) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index f7843bd30..9473b4aba 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.1.2 +rocm-docs-core==1.1.3 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 02d5f6501..8941877ae 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==1.1.2 +rocm-docs-core==1.1.3 # via -r requirements.in six==1.16.0 # via -- GitLab From ec2bae27ff2b7ac658bfb92f533d34db15977eec Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 23 May 2024 09:17:02 -0700 Subject: [PATCH 32/96] Split the gemm_multi_abd instances. 
(#1306) * split the gemm_multi_abd instances * update the dates --- .../gpu/gemm_multi_abd/CMakeLists.txt | 7 +- ..._abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++++++++++ ...bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++++++++++ ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 108 +----------------- ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 59 ++++++++++ ...iply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++++++++++ ...bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++++++++++ ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 107 +---------------- ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 58 ++++++++++ 9 files changed, 357 insertions(+), 214 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt index 7c22d8681..5af7322b1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt @@ -2,9 +2,14 @@ set(GEMM_MULTI_ABD_INSTANCES) list(APPEND GEMM_MULTI_ABD_INSTANCES + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp - + device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 000000000..573dcc7d7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple<>, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple<>, + EDataType, + AElementOp, + Multiply, + PassThrough>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + PassThrough, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + PassThrough, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 000000000..6833ab20f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + Multiply, + Add>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + Add, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + Add, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 98546de04..7cbf55e5f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include @@ -52,112 +52,6 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( Interwave>{}); } -void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - Multiply, - Add>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Multiply, - Add, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Multiply, - Add, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( - std::vector, - ck::Tuple<>, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple<>, - EDataType, - AElementOp, - Multiply, - PassThrough>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - PassThrough, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - PassThrough, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( - std::vector, - ck::Tuple<>, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple<>, - EDataType, - AElementOp, - Multiply, - FastGelu>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - FastGelu, - GemmMNKPadding, - Interwave>{}); - - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - FastGelu, - GemmMNKPadding, - Interwave>{}); -} - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 000000000..044fc2672 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple<>, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple<>, + EDataType, + AElementOp, + Multiply, + FastGelu>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + FastGelu, + GemmMNKPadding, + Interwave>{}); + + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + FastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 000000000..12bcf1925 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + Multiply, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + Multiply, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 000000000..e4a04d48b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyAdd>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAdd, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAdd, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 5c46730ea..590e89284 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include @@ -52,111 +52,6 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_i Interwave>{}); } -void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - PassThrough, - MultiplyAdd>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyAdd, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyAdd, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - PassThrough, - Multiply>>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - Multiply, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - Multiply, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - PassThrough, - MultiplyFastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyFastGelu, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyFastGelu, - GemmMNKPadding, - Interwave>{}); -} - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 000000000..5741ee29a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyFastGelu, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyFastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -- GitLab From 02fa2c298bb4aaa89f62884bef4d2b27afc454d6 Mon Sep 17 00:00:00 2001 From: Joseph Macaranas <145489236+amd-jmacaran@users.noreply.github.com> Date: Thu, 23 May 2024 18:21:34 -0400 Subject: [PATCH 33/96] Enable external CI pipeline triggers (#1310) --- .azuredevops/rocm-ci.yml | 42 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 .azuredevops/rocm-ci.yml diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml new file mode 100644 index 000000000..8c5285675 --- /dev/null +++ b/.azuredevops/rocm-ci.yml @@ -0,0 +1,42 @@ +resources: + repositories: + - repository: pipelines_repo + type: github + endpoint: ROCm + name: ROCm/ROCm + +variables: +- group: common +- template: /.azuredevops/variables-global.yml@pipelines_repo + +trigger: + batch: true + branches: + include: + - develop + paths: + exclude: + - .github + - docs + - '.*.y*ml' + - '*.md' + - Jenkinsfile + - LICENSE + +pr: + autoCancel: true + branches: + include: + - develop + paths: + exclude: + - .github + - docs + - '.*.y*ml' + - '*.md' + - Jenkinsfile + - LICENSE + drafts: false + +jobs: + - template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo -- GitLab From 5055b3bdcb5a7f0b8f359b606d3c5b75efd6df54 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 28 May 2024 11:13:21 +0800 Subject: [PATCH 34/96] [CK_TILE] support group from cmdline (#1295) * support cmdline seqlen decode * silent print * update readme * update kernel launch 3d * update tile partitioner * fix spill for bf16 * modify based on comment * modify payload_t * fix bug for alibi mode * fix alibi test err * refactor kernel launch, support select timer * add missing file * remove useless code * add some comments --- example/ck_tile/01_fmha/README.md | 1 + example/ck_tile/01_fmha/fmha_fwd.cpp | 79 +++++-- example/ck_tile/01_fmha/generate.py | 24 ++- example/ck_tile/01_fmha/script/smoke_test.sh | 1 + example/ck_tile/01_fmha/utils.hpp | 102 ++++++++- .../core/arch/amd_buffer_addressing.hpp | 39 +++- include/ck_tile/core/config.hpp | 4 + include/ck_tile/host.hpp | 1 + include/ck_tile/host/device_memory.hpp | 59 +++-- include/ck_tile/host/kernel_launch.hpp | 202 ++++++------------ 
include/ck_tile/host/stream_config.hpp | 17 ++ include/ck_tile/host/timer.hpp | 79 +++++++ .../fmha/block/block_position_encoding.hpp | 6 +- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 4 +- .../fmha/kernel/fmha_fwd_tile_partitioner.hpp | 59 ++++- .../position_embedding/position_embedding.cpp | 124 +++++------ 16 files changed, 537 insertions(+), 264 deletions(-) create mode 100644 include/ck_tile/host/timer.hpp diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index a3248e2a5..0bb540877 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -34,6 +34,7 @@ args: if not equal to h, then this is GQA/MQA case -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary + also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) -s_k seqlen_k, -1 means equal to s (default:-1) -d head dim for q, k (default:128) -d_v head dim for v, -1 means equal to d (default:-1) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 74cb3657e..91fc07d83 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -44,11 +44,18 @@ auto create_args(int argc, char* argv[]) "-1", "num of head, for k/v, -1 means equal to h\n" "if not equal to h, then this is GQA/MQA case") - .insert("s", - "3328", - "seqlen_q. if group-mode, means the average value of seqlen_q\n" - "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") + .insert( + "s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" + "also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)") .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("s_kpad", + "-1", + "seqlen_k stride between 2 tokens, currently used in group-mode only\n" + "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" + "along seqlen, instead of packed. same as xformer kv_padding") .insert("d", "128", "head dim for q, k") .insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("scale_s", @@ -103,6 +110,7 @@ auto create_args(int argc, char* argv[]) "11939", "random seed used for initializing input tensors. 
0 for " "non-deterministic seed") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -177,10 +185,20 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - if(seqlen_k < 0) - seqlen_k = seqlen_q; + auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode, + batch, + arg_parser.get_str("s"), + arg_parser.get_str("s_k"), + arg_parser.get_str("s_kpad")); + +#if 0 + // clang-format off + std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_kpads:"; for(auto xx : seqlen_kpads) { std::cout << xx << ","; } std::cout << std::endl; + // clang-format on +#endif + ck_tile::index_t hdim_q = arg_parser.get_int("d"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); if(hdim_v < 0) @@ -229,7 +247,8 @@ bool run(const ck_tile::ArgParser& arg_parser) bool lse = arg_parser.get_bool("lse"); bias_info bias = bias_info::decode(arg_parser.get_str("bias")); - mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + mask_info mask = mask_info::decode( + arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore std::string init_method = arg_parser.get_str("init"); std::optional seed = arg_parser.get_uint32("seed"); @@ -242,11 +261,16 @@ bool run(const ck_tile::ArgParser& arg_parser) int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); - ck_tile::stream_config stream_config{ - nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat}; + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + stream_warmup, + stream_repeat, + arg_parser.get_str("timer") == std::string("gpu")}; - const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); - const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + const auto seqstart_q_host = to_seqstarts(seqlen_qs); + const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); using TypeConfig = FmhaFwdTypeConfig; @@ -302,9 +326,11 @@ bool run(const ck_tile::ArgParser& arg_parser) // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); const ck_tile::index_t shape_seqlen_q = - (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = - (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + (mode == mode_enum::batch ? seqlen_ks[0] + : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() + : seqstart_k_with_padding_host.back())); ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); @@ -407,6 +433,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 
0 : seqlen_ks.size() * sizeof(int32_t)); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); @@ -414,7 +441,9 @@ bool run(const ck_tile::ArgParser& arg_parser) v_buf.ToDevice(v_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqstart_k_host.data()); + seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() + : seqstart_k_with_padding_host.data()); + seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data()); alibi_slope_buf.ToDevice(alibi_slope_host.data()); // clang-format off @@ -430,7 +459,9 @@ bool run(const ck_tile::ArgParser& arg_parser) const std::string prec = arg_parser.get_str("prec"); std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch - << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0] + << (seqlen_kpads[0] < 0 ? "" + : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout << std::flush; @@ -460,7 +491,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return ck_tile::identity{}; }(); - auto fmha_args = [&]() { + auto fmha_args = [&, k_paddings_ = seqlen_kpads]() { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// seqlen_k] in this example, hence both the 'batch_stride_bias' & @@ -506,7 +537,7 @@ bool run(const ck_tile::ArgParser& arg_parser) o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(), - nullptr, + k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(), shape_seqlen_q, shape_seqlen_k, batch, @@ -576,7 +607,10 @@ bool run(const ck_tile::ArgParser& arg_parser) // adjust matrix index according to the mode const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); - const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + const ck_tile::index_t key_offset = + (mode == mode_enum::batch + ? 0 + : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb])); const auto v_host_ref_lengths = std::array{nhead, hdim_v, real_seqlen_k}; @@ -661,7 +695,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else { return ck_tile::Alibi{ - 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL}; + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; } }(); @@ -671,7 +705,8 @@ bool run(const ck_tile::ArgParser& arg_parser) for(auto i_h = 0; i_h < nhead; i_h++) { SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); - alibi_host.slope = current_slope; + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? 
current_slope + : -current_slope; for(auto i_r = 0; i_r < real_seqlen_q; i_r++) { for(auto i_c = 0; i_c < real_seqlen_k; i_c++) diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 51fecd07b..f0180d6db 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -78,6 +78,11 @@ BOOL_MAP = { "f" : "false" } +TILE_PARTITIONER_MAP = { + "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", + "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", +} + DIRECTIONS = ["fwd"] GEN_DIR = "" # in Cmake, have to generate files in same folder @@ -107,7 +112,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_dvpad}, {F_bias}, {F_lse}, - {F_squant}, + {F_squant}, {F_occupancy}>; using fmha_mask_{F_idx} = {F_mask}; @@ -136,7 +141,7 @@ using fmha_epilogue_{F_idx} = {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel, + ck_tile::FmhaFwdKernel<{F_tile_partitioner}, fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>; @@ -154,7 +159,7 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ @@ -389,6 +394,12 @@ class FmhaFwdKernel: F_pipeline : FmhaFwdPipeline mask_impl : str + def get_tp(self) -> str: + if self.F_mode == 'group': + return 'hbs' + else: + return 'shb' + @property def template(self) -> str: kernel_body = str() @@ -413,7 +424,7 @@ class FmhaFwdKernel: F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_squant = BOOL_MAP[self.F_pipeline.F_squant], @@ -421,12 +432,13 @@ class FmhaFwdKernel: F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\ + return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ self.F_tile.name + '_' + self.F_pipeline.name @property diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test.sh index 2c4bb562a..21f679e11 100755 --- a/example/ck_tile/01_fmha/script/smoke_test.sh +++ b/example/ck_tile/01_fmha/script/smoke_test.sh @@ -28,6 +28,7 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME 
$COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS done done diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index e10ae617d..737efd825 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -4,12 +4,14 @@ #pragma once #include +#include #include #include #include #include #include #include +#include #include "ck_tile/core/container/span.hpp" @@ -37,12 +39,14 @@ std::vector to_seqstarts(ck_tile::span seqlens) std::vector generate_seqlens(mode_enum mode, unsigned count, - int32_t seqlens_sum, + int32_t seqlen_avg, + int32_t seqlen_max = -1, // if not negative, clamp max std::optional seed = std::nullopt) { assert(0 < count); - std::vector seqlens(count, seqlens_sum); + std::vector seqlens( + count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg); if(mode == mode_enum::group && 1 < count) { @@ -55,7 +59,7 @@ std::vector generate_seqlens(mode_enum mode, std::uniform_int_distribution step_dist(1, count - 1); auto next_step = std::bind(step_dist, std::ref(random_engine)); - for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat) + for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) { const size_type to_decrease = next_idx(); // make sure each elements of seqlens is always greater than 0 @@ -66,6 +70,11 @@ std::vector generate_seqlens(mode_enum mode, const size_type to_increase = (to_decrease + next_step()) % count; + if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max) + { + continue; + } + --seqlens[to_decrease]; ++seqlens[to_increase]; } @@ -76,10 +85,91 @@ std::vector generate_seqlens(mode_enum mode, std::vector generate_seqstarts(mode_enum mode, unsigned count, - int32_t seqlens_sum, + int32_t seqlen_avg, + int32_t seqlen_max = -1, std::optional seed = std::nullopt) { - return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed)); + return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed)); +} + +/* + * decode the seqlen string from cmdline + * example (assume batch=3) + * q_val=1,2,3 k_val=4,5,6 -> OK + * q_val=1,2,3 -> OK, k same as q + * q_val=1,2 -> OK, q will rand remaining 1 element, k same as q + * q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element + * q_val=1,2,3,4 -> OK, but ignore exceed one + * + * q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q + * q_val=1,2 k_val=4 -> not OK, k must have same splits with q + */ +std::tuple, + std::vector, + std::vector> +decode_seqlen(mode_enum mode, + ck_tile::index_t batch, + std::string q_val, + std::string k_val, + std::string k_pad_val, + std::optional seed = std::nullopt) +{ +#define _S2I_(str_) static_cast(std::atoi((str_).c_str())) + if(mode == mode_enum::batch) + { + ck_tile::index_t q = _S2I_(q_val); + ck_tile::index_t k = _S2I_(k_val); + auto s_q = std::vector(batch, q); + auto s_k = std::vector(batch, k < 0 ? 
q : k); + auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding + return std::make_tuple(s_q, s_k, s_kpad); + } + else + { + ck_tile::index_t idx = 0; + std::string::size_type pos_q = 0; + std::string::size_type pos_k = 0; + std::string::size_type pos_kp = 0; + std::vector s_q; + std::vector s_k; + std::vector s_kpad; + while(true) + { + auto found_q = q_val.find(',', pos_q); + auto found_k = k_val.find(',', pos_k); + auto found_kp = k_pad_val.find(',', pos_kp); + + ck_tile::index_t q = _S2I_( + q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q)); + ck_tile::index_t k = _S2I_( + k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k)); + ck_tile::index_t kp = _S2I_(k_pad_val.substr( + pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp)); + + s_q.push_back(q); + s_k.push_back(k < 0 ? q : k); + s_kpad.push_back(kp); + idx++; + if(found_q == std::string::npos || idx >= batch) + { + break; + } + pos_q = found_q + 1; + pos_k = found_k == std::string::npos ? pos_k : found_k + 1; + pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1; + } + if(idx < batch) + { + auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed); + auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed); + + s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); + s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); + s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back()); + } + return std::make_tuple(s_q, s_k, s_kpad); + } +#undef _S2I_ } int env_get_int(const char* var_name, int default_int) @@ -87,6 +177,6 @@ int env_get_int(const char* var_name, int default_int) char* v = getenv(var_name); int r = default_int; if(v) - r = atoi(v); + r = std::atoi(v); return r; } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 53f42a742..ac2f0cab9 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -29,6 +29,25 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz return __builtin_bit_cast(int32x4_t, res); } +namespace impl { +// below type indicate the data type used for buffer load inline asm +// clang-format off +template struct buffer_load_trait; + +template struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; }; +template struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; }; +template struct buffer_load_trait<4 , T> { using payload_t = float; }; +template struct buffer_load_trait<2 , T> { using payload_t = float; }; +template struct buffer_load_trait<1 , T> { using payload_t = float; }; + +#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA +template<> struct buffer_load_trait<16, thread_buffer> { using payload_t = bf16x8_t; }; +template<> struct buffer_load_trait<8 , thread_buffer> { using payload_t = bf16x4_t; }; +template<> struct buffer_load_trait<4 , thread_buffer> { using payload_t = bf16x2_t; }; +#endif +// clang-format on +} // namespace impl + // TODO: glc/slc/... 
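+// Usage sketch (illustrative; assumes the thread_buffer alias is spelled
+// thread_buffer<type, elements>): with CK_TILE_BUFFER_LOAD_RAW_BF16_WA enabled,
+//   impl::buffer_load_trait<16, thread_buffer<bf16_t, 8>>::payload_t -> bf16x8_t
+//   impl::buffer_load_trait<16, thread_buffer<fp16_t, 8>>::payload_t -> fp32x4_t (generic case)
+// Only the register type used in the inline-asm "+v" constraint changes; the
+// bytes moved by each buffer_load_dwordxN instruction stay the same.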
template struct buffer_load; @@ -48,7 +67,7 @@ struct buffer_load<16> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 16); - using mbuf_t = fp32x4_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -68,7 +87,7 @@ struct buffer_load<8> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 8); - using mbuf_t = fp32x2_t; + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -88,7 +107,7 @@ struct buffer_load<4> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -108,7 +127,7 @@ struct buffer_load<2> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -128,7 +147,7 @@ struct buffer_load<1> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -152,7 +171,7 @@ struct buffer_load_if<16> { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" @@ -177,7 +196,7 @@ struct buffer_load_if<8> { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x2_t; + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" @@ -201,7 +220,7 @@ struct buffer_load_if<4> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" @@ -225,7 +244,7 @@ struct buffer_load_if<2> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" @@ -249,7 +268,7 @@ struct buffer_load_if<1> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 601aad19b..10045d8f7 100644 --- a/include/ck_tile/core/config.hpp +++ 
b/include/ck_tile/core/config.hpp @@ -171,3 +171,7 @@ #ifndef CK_TILE_FMHA_FWD_FAST_EXP2 #define CK_TILE_FMHA_FWD_FAST_EXP2 0 #endif + +#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA +#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1 +#endif diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 0c4a77822..98a3bb7d7 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -20,3 +20,4 @@ #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" +#include "ck_tile/host/timer.hpp" diff --git a/include/ck_tile/host/device_memory.hpp b/include/ck_tile/host/device_memory.hpp index 91463a06a..7c8549f74 100644 --- a/include/ck_tile/host/device_memory.hpp +++ b/include/ck_tile/host/device_memory.hpp @@ -27,7 +27,14 @@ struct DeviceMem DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {} DeviceMem(std::size_t mem_size) : mMemSize(mem_size) { - HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + if(mMemSize != 0) + { + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + else + { + mpDeviceBuf = nullptr; + } } void Realloc(std::size_t mem_size) { @@ -36,7 +43,14 @@ struct DeviceMem HIP_CHECK_ERROR(hipFree(mpDeviceBuf)); } mMemSize = mem_size; - HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + if(mMemSize != 0) + { + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + else + { + mpDeviceBuf = nullptr; + } } void* GetDeviceBuffer() const { return mpDeviceBuf; } std::size_t GetBufferSize() const { return mMemSize; } @@ -47,15 +61,18 @@ struct DeviceMem HIP_CHECK_ERROR( hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); } - else - { - throw std::runtime_error("ToDevice with an empty pointer"); - } + // else + // { + // throw std::runtime_error("ToDevice with an empty pointer"); + // } } void ToDevice(const void* p, const std::size_t cpySize) const { - HIP_CHECK_ERROR( - hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); + if(mpDeviceBuf) + { + HIP_CHECK_ERROR( + hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); + } } void FromDevice(void* p) const { @@ -63,14 +80,17 @@ struct DeviceMem { HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); } - else - { - throw std::runtime_error("FromDevice with an empty pointer"); - } + // else + // { + // throw std::runtime_error("FromDevice with an empty pointer"); + // } } void FromDevice(void* p, const std::size_t cpySize) const { - HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + if(mpDeviceBuf) + { + HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + } } void SetZero() const { @@ -82,13 +102,16 @@ struct DeviceMem template void SetValue(T x) const { - if(mMemSize % sizeof(T) != 0) + if(mpDeviceBuf) { - throw std::runtime_error("wrong! not entire DeviceMem will be set"); - } + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! not entire DeviceMem will be set"); + } - // TODO: call a gpu kernel to set the value (?) - set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + // TODO: call a gpu kernel to set the value (?) 
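+            // e.g. a zero-sized buffer such as the fmha example's seqlen_k_buf
+            // (allocated with size 0 when no k-padding is requested) keeps
+            // mpDeviceBuf == nullptr, so ToDevice/FromDevice/SetValue become
+            // no-ops instead of throwing or touching a null pointer.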
+ set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + } } ~DeviceMem() { diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index 7053888ab..e9c5a0c25 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/timer.hpp" #include #include @@ -14,153 +15,92 @@ template -CK_TILE_HOST float launch_and_time_kernel(const stream_config& s, - F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - Args... args) +// +// return a anonymous functor(lambda) to be called later +// the KernelImpl should be a class without non-static data member, or let's say +// can be instantiate with "KernelImpl{}" +// +// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl +// +template +CK_TILE_HOST auto +make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { -#if CK_TILE_TIME_KERNEL - if(s.time_kernel_) - { - // warm up - for(int i = 0; i < s.cold_niters_; ++i) - { - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } - - const int nrepeat = s.nrepeat_; - hipEvent_t start, stop; - - HIP_CHECK_ERROR(hipEventCreate(&start)); - HIP_CHECK_ERROR(hipEventCreate(&stop)); - - HIP_CHECK_ERROR(hipDeviceSynchronize()); - HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_)); - - for(int i = 0; i < nrepeat; ++i) - { - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } - - HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_)); - HIP_CHECK_ERROR(hipEventSynchronize(stop)); - - float total_time = 0; - - HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); + const auto kernel = kentry; - return total_time / nrepeat; - } - else - { + return [=](const stream_config& s) { kernel<<>>(args...); - hip_check_error(hipGetLastError()); - return 0; - } -#else - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - return 0; -#endif + }; } -template -CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s, - PreProcessFunc preprocess, - F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - Args... args) +// clang-format off +/* + * launch_kernel() + * + * this is the function to launch arbitrary number of kernels with optional timer(selected by stream_config) + * the callables should have signature as "operator()(const stream_config& s){ ... }" to call + * + * the simplest way is pass in a lambda function, with "[=](const stream_config& s){ call_your_kernel_here() }" + * as signature, for the callable (pay attention to the capture list) + * + * e.g. + * ck_tile::launch_kernel(s, + * [=](const stream_config& s){ hipMemset(ptr, 0, size) }, + * [=](const stream_config& s){ some_kernel<<>>(arg); } + * ); + * + * if you use ck_tile kernel, or similiar to this style (structure with "static __device__ operator()(...){}") + * you can pass your kernel to ck_tile::make_kernel(), which will create a anonymous functor for you, + * then pass it to ck_tile::launch_kernel() + * + * e.g. + * ck_tile::launch_kernel(s, + * ck_tile::make_kernel(kernel_0{}, grids0, blocks0, 0, kargs0), + * ck_tile::make_kernel(kernel_1{}, grids1, blocks1, 0, kargs1), + * ...); + **/ +// clang-format on +template +CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... 
callables) { -#if CK_TILE_TIME_KERNEL - if(s.time_kernel_) - { -#if CK_TILE_DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up 1 time\n"); -#endif - // warm up - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - - const int nrepeat = 10; -#if CK_TILE_DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif - hipEvent_t start, stop; - - HIP_CHECK_ERROR(hipEventCreate(&start)); - HIP_CHECK_ERROR(hipEventCreate(&stop)); - - HIP_CHECK_ERROR(hipDeviceSynchronize()); - HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_)); + // clang-format off + if(!s.time_kernel_) { + (callables(s),...); hip_check_error(hipGetLastError()); + return 0; + } + if(s.is_gpu_timer_) { + gpu_timer timer {}; - for(int i = 0; i < nrepeat; ++i) - { - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } + // warmup + for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); - HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_)); - HIP_CHECK_ERROR(hipEventSynchronize(stop)); + timer.start(s.stream_id_); + for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + timer.stop(s.stream_id_); - float total_time = 0; + return timer.duration() / s.nrepeat_; + } + else { + cpu_timer timer {}; - HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); + // warmup + for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); - return total_time / nrepeat; - } - else - { - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); + timer.start(s.stream_id_); + for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + timer.stop(s.stream_id_); - return 0; + return timer.duration() / s.nrepeat_; } -#else - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - - return 0; -#endif + // clang-format on } -template -CK_TILE_HOST float launch_kernel(const stream_config& s, - KernelImpl kernel_impl, - dim3 grid_dim, - dim3 block_dim, - std::size_t dynamic_smem_byte, - Args... 
args) -{ - const auto kernel = kentry; - - return launch_and_time_kernel( - s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...); -} } // namespace ck_tile diff --git a/include/ck_tile/host/stream_config.hpp b/include/ck_tile/host/stream_config.hpp index d29c6f0fa..47cf0fd5e 100644 --- a/include/ck_tile/host/stream_config.hpp +++ b/include/ck_tile/host/stream_config.hpp @@ -6,6 +6,22 @@ #include namespace ck_tile { +/* + * construct this structure with behavior as: + * + * // create stream config with default stream(NULL), and not timing the kernel + * stream_config s = stream_config{}; + * + * // create stream config with _some_stream_id_, and not timing the kernel + * stream_config s = stream_config{_some_stream_id_}; + * + * // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default + * stream_config s = stream_config{_some_stream_id_, true}; + * + * // create stream config with _some_stream_id_, and benchmark using cpu timer + * stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false}; + **/ + struct stream_config { hipStream_t stream_id_ = nullptr; @@ -13,5 +29,6 @@ struct stream_config int log_level_ = 0; int cold_niters_ = 3; int nrepeat_ = 10; + bool is_gpu_timer_ = true; // keep compatible }; } // namespace ck_tile diff --git a/include/ck_tile/host/timer.hpp b/include/ck_tile/host/timer.hpp new file mode 100644 index 000000000..e2baeaef7 --- /dev/null +++ b/include/ck_tile/host/timer.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include +#include +#include + +namespace ck_tile { + +struct gpu_timer +{ + CK_TILE_HOST gpu_timer() + { + HIP_CHECK_ERROR(hipEventCreate(&start_evt)); + HIP_CHECK_ERROR(hipEventCreate(&stop_evt)); + } + + CK_TILE_HOST ~gpu_timer() noexcept(false) + { + HIP_CHECK_ERROR(hipEventDestroy(start_evt)); + HIP_CHECK_ERROR(hipEventDestroy(stop_evt)); + } + + CK_TILE_HOST void start(const hipStream_t& s) + { + HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipEventRecord(start_evt, s)); + } + + CK_TILE_HOST void stop(const hipStream_t& s) + { + HIP_CHECK_ERROR(hipEventRecord(stop_evt, s)); + HIP_CHECK_ERROR(hipEventSynchronize(stop_evt)); + } + // return in ms + CK_TILE_HOST float duration() const + { + float ms = 0; + HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt)); + return ms; + } + + private: + hipEvent_t start_evt, stop_evt; +}; + +struct cpu_timer +{ + // torch.utils.benchmark.Timer(), there is a sync inside each timer callback + CK_TILE_HOST void start(const hipStream_t&) + { + HIP_CHECK_ERROR(hipDeviceSynchronize()); + start_tick = std::chrono::high_resolution_clock::now(); + } + // torch.utils.benchmark.Timer(), there is a sync inside each timer callback + CK_TILE_HOST void stop(const hipStream_t&) + { + HIP_CHECK_ERROR(hipDeviceSynchronize()); + stop_tick = std::chrono::high_resolution_clock::now(); + } + // return in ms + CK_TILE_HOST float duration() const + { + double sec = + std::chrono::duration_cast>(stop_tick - start_tick) + .count(); + return static_cast(sec * 1e3); + } + + private: + std::chrono::time_point start_tick; + std::chrono::time_point stop_tick; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp index 9c6c35390..c2fdaf3a1 100644 --- 
a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp +++ b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp @@ -23,13 +23,13 @@ VERTICAL: [0] 1 2 3 4 5 [0] 1 2 3 4 5 -TOP_LEFT: +TOP_LEFT(but negative): [0] 1 2 3 4 5 1 [0] 1 2 3 4 2 1 [0] 1 2 3 3 2 1 [0] 1 2 -FROM_BOTTOM_RIGHT: +FROM_BOTTOM_RIGHT(but negative): 2 1 [0] 1 2 3 3 2 1 [0] 1 2 4 3 2 1 [0] 1 @@ -54,7 +54,7 @@ struct Alibi index_t x_total_, AlibiMode mode_ = AlibiMode::VERTICAL) { - slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope; + slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_; shift_left_up = [&]() { if(RowMajor) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 10ce7395a..9992d56ea 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -76,7 +76,7 @@ struct FmhaFwdKernel return n.empty() ? n : std::string("p") + n; }(); return _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + @@ -702,7 +702,7 @@ struct FmhaFwdKernel else { return Alibi{ - slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL}; + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; } } else diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp index 52f458c72..e40b00668 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp @@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + static constexpr const char* name = "shb"; + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) { // TODO: this may need tuning return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * @@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner } }; +template +using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner; + +template +struct FmhaFwdTilePartitioner_HBS +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; + static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; + static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + + static constexpr const char* name = "hbs"; + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, kM0) * + ck_tile::integer_divide_ceil(hdim_v_, kN1)); + } + + 
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } +}; + } // namespace ck_tile diff --git a/test/position_embedding/position_embedding.cpp b/test/position_embedding/position_embedding.cpp index e295ec454..4e13225dd 100644 --- a/test/position_embedding/position_embedding.cpp +++ b/test/position_embedding/position_embedding.cpp @@ -131,74 +131,74 @@ int main() 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}); - rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, - 1, 0, 1, 2, 3, 4, - 2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2}); - - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0, - 4, 3, 2, 1, - 5, 4, 3, 2}); - - rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); - - rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2, - 4, 3, 2, 1, 0, 1, - 5, 4, 3, 2, 1, 0}); - - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, - 1, 2, 3, 4, - 0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0}); - - rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5, + -1, 0, -1, -2, -3, -4, + -2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0, + -4, -3, -2, -1, + -5, -4, -3, -2}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2, + -4, -3, -2, -1, 0, -1, + -5, -4, -3, -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5, + -1, -2, -3, -4, + 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}); - rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, - 1, 0, 1, 2, 3, 4, - 2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2}); - - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0, - 4, 3, 2, 1, - 5, 4, 3, 2}); - - rtn &= 
test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); - - rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2, - 4, 3, 2, 1, 0, 1, - 5, 4, 3, 2, 1, 0}); - - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, - 1, 2, 3, 4, - 0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0}); - - rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5, + -1, 0, -1, -2, -3, -4, + -2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0, + -4, -3, -2, -1, + -5, -4, -3, -2}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2, + -4, -3, -2, -1, 0, -1, + -5, -4, -3, -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5, + -1, -2, -3, -4, + 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); rtn &= test_alibi_slope_generation(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625}); rtn &= test_alibi_slope_generation(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692, -- GitLab From 80db62f08d288f5d451753e7fc2aa4d602382892 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 28 May 2024 12:04:22 -0500 Subject: [PATCH 35/96] add f8 gemm multiD with both row/col wise scale (#1300) * add f8 gemm with multiD for both row/col wise * change compute_type to fp8 * changed tuning parameters in the example * add rcr example --- .../65_gemm_multiply_multiply/CMakeLists.txt | 1 + .../gemm_multiply_multiply_xdl_fp16.cpp | 274 +++ ...hread_group_tensor_slice_transfer_v7r3.hpp | 220 ++ ...device_gemm_multiple_d_xdl_cshuffle_v3.hpp | 730 ++++++ .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 2082 +++++++++++++++++ .../threadwise_tensor_slice_transfer_v7r3.hpp | 648 +++++ 6 files changed, 3955 insertions(+) create mode 100644 example/65_gemm_multiply_multiply/CMakeLists.txt create mode 100644 example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp create mode 100644 include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt new file mode 100644 index 000000000..f3594d153 --- /dev/null +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_multiply_multiply_xdl_fp16 gemm_multiply_multiply_xdl_fp16.cpp) diff --git 
a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp new file mode 100644 index 000000000..b0e75a559 --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using B0DataType = FP8; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RRR + ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +///###### RCR + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + 
a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp new file mode 100644 index 000000000..46d0c6ac2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags, + index_t NumThreadScratch = 1> +struct ThreadGroupTensorSliceTransfer_v7r3 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + + template + using is_tuple = decltype(std::declval().IsTuple()); + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + if constexpr(is_detected::value) + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + else + threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + static_for<0, SrcDescs::Size(), 1>{}( + [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); }); + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + static_for<0, DstDescs::Size(), 1>{}( + [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); }); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7r3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp new file mode 100644 index 000000000..2275d8364 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -0,0 +1,730 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { +#if 0 + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else +#endif + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + 
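// Note (illustrative annotation, not part of the patch): the v2 pipeline keeps
// BlockwiseGemmPipe::PrefetchStages K-blocks in flight, so after the hot loop the
// leftover iteration count ("tail number") ranges from One up to the prefetch
// depth; one kernel is instantiated per possible tail, and these
// if constexpr(PrefetchStages > N) guards avoid instantiating tails that can
// never occur for the chosen pipeline. Condensed sketch of the dispatch pattern:
#if 0
    const auto tail = GridwiseGemm::CalculateKBlockLoopTailNum(K_split);
    if(tail == TailNumber::Two)   { /* launch the TailNumber::Two kernel */ }
    if(tail == TailNumber::Three) { /* launch the TailNumber::Three kernel */ }
    // ... one branch per tail value supported by the prefetch depth
#endif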
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { +#if 0 + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else +#endif + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return 
GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || 
defined(__gfx941__) || defined(__gfx942__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const 
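// Note (illustrative annotation, not part of the patch): the Calculate*Padded
// helpers above size one split-K batch; CalculateAK0Padded, for example, rounds K
// up to whole KBatch * KPerBlock grains and reports the per-batch K in AK1-wide
// chunks. Worked example with hypothetical numbers (KPerBlock = 64, AK1Value = 8):
#if 0
    index_t K = 1000, K_Batch = 4;
    index_t K_t = K_Batch * 64;                    // 256
    index_t AK0 = (K + K_t - 1) / K_t * (64 / 8);  // 4 * 8 = 32 AK1-chunks per batch
#endif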
TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + 
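// Note (illustrative annotation, not part of the patch): MakeAGridDescriptor_AK0_M_AK1
// above and MakeBGridDescriptor_BK0_N_BK1 here follow the same recipe: build the raw
// (rows, K) view from the stride, right-pad whichever dimensions GemmSpec requests,
// then unmerge K into (K0, K1) so the blockwise copy can load K1 contiguous elements
// per vector. The padding is plain round-up arithmetic, e.g. with hypothetical numbers:
#if 0
    index_t M = 1000, MPerBlock_ = 128;
    index_t MPad = (M + MPerBlock_ - 1) / MPerBlock_ * MPerBlock_; // 1024, i.e. 24 padded rows
#endif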
else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + 
GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t 
StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead; + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeA); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_transform(make_tuple(Number{}, + Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. 
+ // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1 + ? 
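// Note (illustrative annotation, not part of the patch): LDS has 32 banks of 4 bytes,
// i.e. 128 bytes per bank row, so the expression below counts how many KPerBlock-wide
// rows of LDSTypeB fit in one 128-byte row; folding that many rows into a single
// "NLdsLayer" before the xor transform spreads consecutive N rows across different
// banks and avoids bank conflicts. Worked example with hypothetical types:
#if 0
    static_assert(32 * 4 / 32 / sizeof(ck::half_t) == 2, "KPerBlock=32, fp16: two rows per 128B layer");
    static_assert(32 * 4 / 64 / sizeof(float) < 1, "KPerBlock=64, fp32: row >= 128B, clamp NLdsLayer to 1");
#endif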
1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeB); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_transform(make_tuple(Number{}, + Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) + + b_block_space_size_aligned * sizeof(LDSTypeB)), + c_block_size * 
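// Note (illustrative annotation, not part of the patch): the GEMM main loop and the
// C-shuffle epilogue never occupy LDS at the same time, so the allocation is the
// maximum of the two footprints rather than their sum. Rough worked example with
// hypothetical fp16 128x64 A/B tiles and an fp32 32x128 shuffle tile:
#if 0
    std::size_t ab_bytes  = 128 * 64 * 2 /*A*/ + 64 * 128 * 2 /*B*/; // 32 KiB
    std::size_t c_bytes   = 32 * 128 * 4;                            // 16 KiB
    std::size_t lds_bytes = ab_bytes > c_bytes ? ab_bytes : c_bytes; // 32 KiB
#endif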
sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + 
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + LDSTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + 
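// Note (illustrative annotation, not part of the patch): the single p_shared block is
// carved into the A tile at offset 0 and the B tile right after A's aligned footprint;
// the offset below is expressed in LDSTypeB elements, hence the scaling by
// sizeof(LDSTypeA) / sizeof(LDSTypeB). Minimal sketch of the same carving with
// hypothetical names (lds_bytes, a_elements_aligned):
#if 0
    __shared__ char smem[lds_bytes];
    auto* a_tile = reinterpret_cast<LDSTypeA*>(smem);
    auto* b_tile = reinterpret_cast<LDSTypeB*>(smem) +
                   a_elements_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB);
#endif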
auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c 
matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + 
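// Note (illustrative annotation, not part of the patch): this v7r3 transfer fuses the
// epilogue: for each element of the shuffled C tile read from LDS it also loads the
// matching Ds elements from global memory, applies CElementwiseOperation, and writes
// the result to E (which aliases the C output here). Roughly, per element and assuming
// two D tensors:
#if 0
    c_element_op(e[idx], c_shuffle[idx], d0[idx], d1[idx]); // op writes e from (c, d0, d1)
#endif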
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto 
c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + LDSTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), 
a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
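// The M0..M4 / N0..N2 lengths queried from the descriptor below correspond to the
// decomposition annotated in the unmerge transform further down: M0 ~ MXdlPerWave,
// M1 = MWave, M2 * M3 * M4 = MPerXdl (and N0 ~ NXdlPerWave, N1 = NWave, N2 = NPerXdl).
// Illustrative arithmetic only, with assumed values rather than a specific instance:
// MPerBlock = 128, MXdlPerWave = 2, MPerXdl = 32 gives MWave = 128 / (2 * 32) = 2
// from the MWave formula above.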
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = 
MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's 
safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp new file mode 100644 index 000000000..ea074144b --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -0,0 +1,648 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" +#include "ck/tensor/static_tensor.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. 
Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags, // Sequence + index_t NumThreadScratch = 1> +struct ThreadwiseTensorSliceTransfer_v7r3 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr auto SrcScalarPerVector = SrcScalarPerVectors{}[I0]; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve, + false>; + + using DstSpaceFillingCurve = SpaceFillingCurve, + false>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r3( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! 
cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); + + bool oob_val = true; + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + + if constexpr(SrcScalarPerVectors{}[i] == 1) + { + auto data_types = SrcDatas{}; + using DataType = remove_cvref_t; + const auto tmp = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + + static_for<0, SrcScalarPerVector, 1>{}( + [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); + } + else + { + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + } + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return elm_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) 
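// For illustration only (assuming a single destination E and a C + D0 + D1 source tuple,
// as in the CDE shuffle epilogue that instantiates this transfer above): unpack2 expands
// the two reference tuples so the call below becomes element_op_(e, c, d0, d1); a
// MultiplyMultiply-style op would then compute e = c * d0 * d1, converted to the E type.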
+ unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val; + + // move coordinate + if constexpr(iAccess.value != src_num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + +#if 1 + template + __device__ void OOBCheck(Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess]; + auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess]; + + static_for<0, nDst, 1>{}([&](auto i) { + using elm_vector_t = typename remove_cvref_t::type; + elm_vectors(i).template AsType()(I0) = + oob_val ? elm_vectors(i).template AsType()[I0] : elm_vector_t{0}; + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + }); + } +#endif + + template + __device__ void + TransposeFromElmToDst(Number thread_scratch_id = Number{}) + { + using DstData = remove_cvref_t; + + using ElmThreadScratch = + StaticTensorTupleOfVectorBuffer; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer; + + ElmThreadScratch elm_thread_scratch_; + DstThreadScratch dst_thread_scratch_; + + elm_thread_scratch_.data_ = + bit_cast(elm_vectors_tuple_[thread_scratch_id]); + + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in 
DstVectorDim + return elm_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from + // dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}( + [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); + } + + dst_vectors_tuple_(thread_scratch_id) = bit_cast(dst_thread_scratch_.data_); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + OOBCheck(thread_scratch_id); + TransposeFromElmToDst(thread_scratch_id); + + // loop over space-filling curve + static_for<0, dst_num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != dst_num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(src_num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(dst_num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + return DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + // constexpr auto src_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of 
transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + // constexpr auto dst_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using ElmVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); + + using ElmVectorTuple = StaticallyIndexedArray; + using DstVectorTuple = StaticallyIndexedArray; + + StaticallyIndexedArray elm_vectors_tuple_; + StaticallyIndexedArray dst_vectors_tuple_; + + using OOBVectorTuple = StaticallyIndexedArray; + StaticallyIndexedArray oob_vectors_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck -- GitLab From 66de8a02baca937311f32ed072e26a5bd15093ae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 28 May 2024 11:36:09 -0700 Subject: [PATCH 36/96] Bump rocm-docs-core from 1.1.3 to 1.2.0 in /docs/sphinx (#1311) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 1.1.3 to 1.2.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v1.1.3...v1.2.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 9473b4aba..06bb9365f 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.1.3 +rocm-docs-core==1.2.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 8941877ae..0883a3355 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==1.1.3 +rocm-docs-core==1.2.0 # via -r requirements.in six==1.16.0 # via -- GitLab From 34f3dfdd619b32a597acb678c0b8a50624434dd8 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 28 May 2024 12:36:06 -0700 Subject: [PATCH 37/96] Build CK library for all supported targets. (#1312) * test library build for all supported targets * increase the number of threads to build lib in CI to 64 --- Jenkinsfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 75800bfc9..e8fd0c3ce 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -911,9 +911,8 @@ pipeline { execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ - -D GPU_TARGETS="gfx90a;gfx1030;gfx1101" \ -D INSTANCES_ONLY=ON \ - -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j32 """ + -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& make -j64 """ } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) -- GitLab From 6fb1f4e03fef8a80ae8b5f139b9d4750e2f1a972 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Sat, 1 Jun 2024 00:46:41 -0500 Subject: [PATCH 38/96] Post-merge fix of PR 1300 (#1313) * add f8 gemm with multiD for both row/col wise * change compute_type to fp8 * changed tuning parameters in the example * add rcr example * post-merge fix * fix * reduce init range --- .../gemm_multiply_multiply_xdl_fp16.cpp | 12 ++++++------ .../device_gemm_multiple_d_xdl_cshuffle_v3.hpp | 2 +- .../grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 14 +++++++------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp index b0e75a559..c584ff20c 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp @@ -59,7 +59,7 @@ struct MultiplyMultiply { const float x0_f = c * d0 * d1; - e = ck::type_convert(x0_f); + e = ck::type_convert(x0_f); } }; @@ -95,7 +95,7 @@ int main(int argc, char* argv[]) ck::index_t K = 4096; ck::index_t StrideA = K; - ck::index_t StrideB = N; + ck::index_t StrideB = K; ck::index_t StrideD = 0; ck::index_t StrideE = N; @@ -164,10 +164,10 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); break; default: a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 2275d8364..c2b5317dd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -83,7 +83,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD -struct GridwiseGemm_xdl_cshuffle_v3 +struct GridwiseGemmMultiD_xdl_cshuffle_v3 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -690,8 +690,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -756,7 +756,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(AK1Number)), @@ -827,8 +827,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_lds_block_desc_permuted = 
transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -890,7 +890,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(BK1Number)), -- GitLab From 3fa7e2a6c4ff1834a8c9bc6e89de776ec3192f5b Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 3 Jun 2024 14:07:30 -0700 Subject: [PATCH 39/96] disable the hipTensor test by default, only run once daily (#1321) --- Jenkinsfile | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index e8fd0c3ce..855fe8dff 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -652,8 +652,8 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.1;COMPILER_VERSION= - 0 21 * * * % ROCMVERSION=6.1;COMPILER_VERSION=;COMPILER_COMMIT= +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.1; + 0 21 * * * % ROCMVERSION=6.1;hipTensor_test=true 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : "" @@ -701,8 +701,8 @@ pipeline { description: "Select whether to build DL kernels (default: OFF)") booleanParam( name: "hipTensor_test", - defaultValue: true, - description: "Use the CK build to verify hipTensor build and tests (default: ON)") + defaultValue: false, + description: "Use the CK build to verify hipTensor build and tests (default: OFF)") string( name: 'hipTensor_branch', defaultValue: 'mainline', -- GitLab From 76827d82ca89ddd78be5be86158e7f15b2c11e14 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Jun 2024 22:41:56 -0700 Subject: [PATCH 40/96] Bump rocm-docs-core from 1.2.0 to 1.2.1 in /docs/sphinx (#1322) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 1.2.0 to 1.2.1. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v1.2.0...v1.2.1) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 06bb9365f..6ab8e14dd 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.2.0 +rocm-docs-core==1.2.1 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 0883a3355..868c0044b 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==1.2.0 +rocm-docs-core==1.2.1 # via -r requirements.in six==1.16.0 # via -- GitLab From 2cab8d39e3c638d890774e43cc4ba6bcb982c54a Mon Sep 17 00:00:00 2001 From: Dan Yao Date: Wed, 5 Jun 2024 02:12:45 +0800 Subject: [PATCH 41/96] CK Tile FA Training kernels (#1286) * FA fwd dropout * FA bwd * epilogue reuse * CMakeLists update * [CK_TILE] support alibi (#1269) * add alibi support * fix code * update code based on comment * Support more hdim * fix fp8 bias * support seqlen_k=0 case * remove unused printf * fix format --------- Co-authored-by: rocking * now fwd/bwd can build * bwd alibi * add bwd validation stream_config * update generated filenames * update bwd kernel launch * CK_TILE_HOST_DEVICE in philox * Transpose -> transpose * format * format * format * Generate the instance for FA required * format * fix error in WarpGemm --------- Co-authored-by: danyao12 Co-authored-by: carlushuang Co-authored-by: rocking Co-authored-by: Po Yen Chen Co-authored-by: Jing Zhang --- example/ck_tile/01_fmha/CMakeLists.txt | 39 +- example/ck_tile/01_fmha/fmha_bwd.cpp | 932 +++++++++++ example/ck_tile/01_fmha/fmha_bwd.hpp | 359 +++++ example/ck_tile/01_fmha/fmha_fwd.cpp | 135 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 114 +- example/ck_tile/01_fmha/generate.py | 702 +++++++- .../ck_tile/01_fmha/script/benchmark_bwd.sh | 21 + .../script/{benchmark.sh => benchmark_fwd.sh} | 0 .../ck_tile/01_fmha/script/smoke_test_bwd.sh | 33 + .../{smoke_test.sh => smoke_test_fwd.sh} | 22 +- include/ck_tile/core.hpp | 3 + .../core/arch/amd_buffer_addressing.hpp | 74 +- .../core/arch/generic_memory_space_atomic.hpp | 175 ++ include/ck_tile/core/numeric/vector_type.hpp | 11 +- include/ck_tile/core/tensor/buffer_view.hpp | 13 +- include/ck_tile/core/tensor/store_tile.hpp | 2 +- include/ck_tile/core/tensor/tensor_view.hpp | 30 +- .../ck_tile/core/tensor/tile_distribution.hpp | 1 + include/ck_tile/core/tensor/tile_window.hpp | 60 + include/ck_tile/core/tensor/update_tile.hpp | 55 + include/ck_tile/core/utility/philox_rand.hpp | 89 ++ include/ck_tile/host.hpp | 1 + include/ck_tile/host/host_tensor.hpp | 36 +- .../reference/reference_batched_dropout.hpp | 33 + include/ck_tile/ops/fmha.hpp | 14 + .../ck_tile/ops/fmha/block/block_dropout.hpp | 329 ++++ .../ck_tile/ops/fmha/block/block_masking.hpp | 72 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 1421 +++++++++++++++++ .../fmha/kernel/fmha_bwd_tile_partitioner.hpp | 54 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 180 ++- .../fmha/kernel/fmha_fwd_tile_partitioner.hpp | 2 +- .../fmha/pipeline/block_fmha_bwd_dot_do_o.hpp | 95 ++ ...block_fmha_bwd_dot_do_o_default_policy.hpp | 20 + ...k_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp | 848 ++++++++++ ...k_dv_pipeline_ks_kts_vr_default_policy.hpp | 20 + ...block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp | 821 
++++++++++ ...dq_dk_dv_pipeline_ks_vr_default_policy.hpp | 20 + ...mha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp | 692 ++++++++ ...v_pipeline_qs_ks_vr_dos_default_policy.hpp | 20 + ...block_fmha_bwd_pipeline_default_policy.hpp | 1343 ++++++++++++++++ .../pipeline/block_fmha_bwd_pipeline_enum.hpp | 16 + .../block_fmha_bwd_pipeline_problem.hpp | 91 ++ .../pipeline/block_fmha_pipeline_problem.hpp | 29 +- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 54 +- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 59 +- .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 45 +- .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 29 +- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 46 +- .../ops/fmha/pipeline/tile_fmha_shape.hpp | 49 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 16 +- include/ck_tile/ops/gemm.hpp | 7 +- .../block_gemm_areg_bgmem_creg_problem.hpp | 25 - .../block/block_gemm_areg_bgmem_creg_v1.hpp | 4 +- ...gemm_areg_bgmem_creg_v1_default_policy.hpp | 2 +- .../block/block_gemm_areg_bsmem_creg_v1.hpp | 161 +- ..._gemm_areg_bsmem_creg_v1_custom_policy.hpp | 2 +- ...gemm_areg_bsmem_creg_v1_default_policy.hpp | 2 +- .../block/block_gemm_areg_bsmem_creg_v2.hpp | 2 +- ..._gemm_areg_bsmem_creg_v2_custom_policy.hpp | 2 +- ...gemm_areg_bsmem_creg_v2_default_policy.hpp | 2 +- .../block/block_gemm_asmem_breg_creg_v1.hpp | 228 +++ ..._gemm_asmem_breg_creg_v1_custom_policy.hpp | 36 + ...gemm_asmem_breg_creg_v1_default_policy.hpp | 56 + .../block_gemm_asmem_bsmem_creg_problem.hpp | 26 - .../block/block_gemm_asmem_bsmem_creg_v1.hpp | 2 +- ...gemm_asmem_bsmem_creg_v1_custom_policy.hpp | 2 +- ...emm_asmem_bsmem_creg_v1_default_policy.hpp | 2 +- ...reg_problem.hpp => block_gemm_problem.hpp} | 4 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 10 +- .../gemm/warp/warp_gemm_attribute_mfma.hpp | 88 + 70 files changed, 9506 insertions(+), 482 deletions(-) create mode 100644 example/ck_tile/01_fmha/fmha_bwd.cpp create mode 100644 example/ck_tile/01_fmha/fmha_bwd.hpp create mode 100644 example/ck_tile/01_fmha/script/benchmark_bwd.sh rename example/ck_tile/01_fmha/script/{benchmark.sh => benchmark_fwd.sh} (100%) create mode 100644 example/ck_tile/01_fmha/script/smoke_test_bwd.sh rename example/ck_tile/01_fmha/script/{smoke_test.sh => smoke_test_fwd.sh} (57%) create mode 100644 include/ck_tile/core/arch/generic_memory_space_atomic.hpp create mode 100644 include/ck_tile/core/tensor/update_tile.hpp create mode 100644 include/ck_tile/core/utility/philox_rand.hpp create mode 100644 include/ck_tile/host/reference/reference_batched_dropout.hpp create mode 100644 include/ck_tile/ops/fmha/block/block_dropout.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp create mode 100644 
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp delete mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp delete mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp rename include/ck_tile/ops/gemm/block/{block_gemm_areg_bsmem_creg_problem.hpp => block_gemm_problem.hpp} (88%) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 85d25c63d..e324f85ed 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,17 +1,29 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt + --direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt ) -# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt +) + +# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory # as current cmake list, otherwise will not figure out the dependency properly -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} +) + +add_custom_command( + OUTPUT ${FMHA_BWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") @@ -22,6 +34,14 @@ add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) +set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_FMHA_BWD}") +add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp) +target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) + # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) @@ -29,16 +49,27 @@ if(NOT DEFINED FMHA_FWD_FAST_EXP2) endif() set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) 
+set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... because they are auto-generated if(FMHA_FWD_FAST_EXP2) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) + list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) + list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() # Allow comparing floating points directly in order to check sentinel values list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) +list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) +target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp new file mode 100644 index 000000000..b1249b5ed --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -0,0 +1,932 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_bwd.hpp" +#include "ck_tile/host.hpp" +#include "mask.hpp" +#include "utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") + .insert("dbias", "0", "output bias gradient or not") + .insert("prec", "fp16", "data type. 
fp16 or bf16") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)") + .insert("kname", "0", "if set to 1 will print kernel name") + .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for random number generator") + .insert("drop_offset", "0", "offset for random number generator") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(int /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k < 0) + nhead_k = nhead; + + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return false; + } + + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + if(seqlen_k < 0) + seqlen_k = seqlen_q; + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v < 0) + hdim_v = hdim_q; + if(hdim_q % 2 != 0 || hdim_v % 2 != 0) + { + std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl; + return false; + } + + bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim + bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale = arg_parser.get_float("scale"); + if(scale == .0f) + scale = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + + bias_info bias = bias_info::decode(arg_parser.get_str("bias")); + bool use_dbias = arg_parser.get_bool("dbias"); + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + if(use_dbias && bias.type != bias_enum::elementwise_bias) + { + std::cerr << "dbias only exists when bias type is elementwise" << std::endl; + return false; + } + + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return false; + } + float p_undrop = 1.0 - p_drop; + uint8_t 
p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; + + bool s_randval = false; + if(p_drop > 0.0f && do_validation) + { + s_randval = true; + } + + mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + + int init_method = arg_parser.get_int("init"); + std::optional seed = arg_parser.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + stream_warmup, + stream_repeat, + arg_parser.get_str("timer") == std::string("gpu")}; + + const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); + const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + + using TypeConfig = FmhaBwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using GemmDataType = typename TypeConfig::GemmDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using AccDataType = typename TypeConfig::AccDataType; + using DDataType = typename TypeConfig::DDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using ODataType = typename TypeConfig::ODataType; + using OGradDataType = typename TypeConfig::OGradDataType; + using QGradDataType = typename TypeConfig::QGradDataType; + using KGradDataType = typename TypeConfig::KGradDataType; + using VGradDataType = typename TypeConfig::VGradDataType; + using BiasGradDataType = typename TypeConfig::BiasGradDataType; + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + { + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + + flop += nhead * (static_cast(3) * static_cast(2) * + real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T + static_cast(2) * static_cast(2) * + real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * real_seqlen_k * hdim_v + + sizeof(ODataType) * real_seqlen_q * hdim_v + + sizeof(OGradDataType) * real_seqlen_q * hdim_v + + sizeof(QGradDataType) * real_seqlen_q * hdim_q + + sizeof(KGradDataType) * real_seqlen_k * hdim_q + + sizeof(VGradDataType) * real_seqlen_k * hdim_v + + sizeof(LSEDataType) * real_seqlen_q); + } + } + + auto get_lengths = [&](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + // host memory for storing all the tensor elements + const ck_tile::index_t shape_batch = (mode == 
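+    // In batch mode every sequence gets its own slot along the leading batch dimension
+    // of the host tensors; in group (varlen) mode all sequences are packed back-to-back
+    // along the sequence dimension instead, so the batch extent collapses to 1 and
+    // shape_seqlen_q / shape_seqlen_k below become the total packed lengths taken from
+    // seqstart_*_host.back().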
mode_enum::batch ? batch : 1); + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + const ck_tile::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + + ck_tile::HostTensor q_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor k_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor v_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)); + ck_tile::HostTensor bias_host( + bias.type == bias_enum::elementwise_bias + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor alibi_slope_host( + bias.type == bias_enum::alibi + ? (bias.rank_info == 0 ? std::array{1, nhead} + : std::array{batch, nhead}) + : std::array{1, 1}); + ck_tile::HostTensor o_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor lse_host( + std::array{batch, nhead, max_seqlen_q}); + ck_tile::HostTensor d_host( + std::array{batch, nhead, max_seqlen_q}); + ck_tile::HostTensor randval_host( + p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + ck_tile::HostTensor dq_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor dk_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor dv_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v)); + ck_tile::HostTensor do_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor dbias_host( + use_dbias + ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + + if(init_method == 0) + { + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(do_host); + } + else if(init_method == 1) + { + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(do_host); + } + else if(init_method == 2) + { + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(bias_host); + ck_tile::FillTrigValue{}(do_host); + } + if(bias.type == bias_enum::alibi) + { + auto slopes = ck_tile::get_alibi_slopes(nhead); + assert(slopes.size() == nhead); + if(bias.rank_info == 0) + { + // alibi in 1*h + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); + } + else + { + // alibi in b*h + for(auto i_b = 0; i_b < batch; i_b++) + { + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); + } + } + } + + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + bias_buf.ToDevice(bias_host.data()); + do_buf.ToDevice(do_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + alibi_slope_buf.ToDevice(alibi_slope_host.data()); + + // clang-format off + auto layout_str = [&](bool permute){ + if (permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if (iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + const std::string prec = arg_parser.get_str("prec"); + + std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias + << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask + << std::flush; + + auto fmha_traits = fmha_bwd_traits{hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + mask.type, + bias.type, + use_dbias, + p_drop > 0.0f}; + auto fmha_args = [&]() { + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & + /// 'nhead_stride_bias' are 0. + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v); + const ck_tile::index_t stride_bias = (max_seqlen_k); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = (i_perm ? 
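+        // Layout reminder for these stride_* / nhead_stride_* / batch_stride_* values:
+        // i_perm/o_perm == true means b*h*s*d, false means b*s*h*d. stride_* steps along
+        // the sequence dimension, nhead_stride_* along heads and batch_stride_* along
+        // batches, e.g. for a b*h*s*d q tensor:
+        //   stride_q       = hdim_q
+        //   nhead_stride_q = shape_seqlen_q * hdim_q
+        //   batch_stride_q = nhead * shape_seqlen_q * hdim_q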
shape_seqlen_k * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_bias = 0; + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_lsed = max_seqlen_q; + const ck_tile::index_t nhead_stride_dbias = + (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v); + const ck_tile::index_t batch_stride_bias = 0; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); + const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); + + return fmha_bwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + do_buf.GetDeviceBuffer(), + d_buf.GetDeviceBuffer(), + randval_buf.GetDeviceBuffer(), + dq_buf.GetDeviceBuffer(), + dk_buf.GetDeviceBuffer(), + dv_buf.GetDeviceBuffer(), + dbias_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + max_seqlen_k, + hdim_q, + hdim_v, + nhead, + nhead_k, + scale, + stride_q, + stride_k, + stride_v, + bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 
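+                               // For ALiBi the bias pointer carries the slope vector, so
+                               // stride_bias is reused as the batch stride of the slopes:
+                               // 0 when the slopes are shared across batches (1*h),
+                               // nhead when they are per batch (b*h).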
0 : nhead) + : stride_bias, + stride_o, + stride_randval, + stride_do, + stride_dk, + stride_dv, + stride_dbias, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_o, + nhead_stride_randval, + nhead_stride_do, + nhead_stride_lsed, + nhead_stride_dbias, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_o, + batch_stride_randval, + batch_stride_do, + batch_stride_lsed, + batch_stride_dk, + batch_stride_dv, + batch_stride_dbias, + mask.left, + mask.right, + static_cast(mask.type), + p_drop, + p_undrop, + s_randval, + {drop_seed, drop_offset}}; + }(); + + float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); + if(ave_time < 0) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return false; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush; + + if(!do_validation) + { + std::cout << std::flush << std::endl; + return true; + } + + bool pass = true; + + std::vector> q_host_refs; + std::vector> k_host_refs; + std::vector> v_host_refs; + std::vector> o_host_refs; + std::vector> randval_host_refs; + std::vector> p_hp_host_refs; + std::vector> p_lp_host_refs; + + randval_buf.FromDevice(randval_host.data()); + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 
0 : seqstart_k_host[wb]); + + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); // lse_g_m + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n + ck_tile::HostTensor s_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n + ck_tile::HostTensor p_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision + ck_tile::HostTensor p_dropped_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision + ck_tile::HostTensor p_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + // clang-format on + + // reference + // S = scale * Q * K^T + ck_tile::reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k + + if(bias.type == bias_enum::elementwise_bias) + { + // elementwise bias + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + ck_tile:: + reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) + { + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + AccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? 
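+                // The reference rebuilds ALiBi as an explicit elementwise bias: each head's
+                // slope (per batch when rank_info != 0) is negated for the non-vertical
+                // top-left / bottom-right modes, and alibi_host.update() then fills in the
+                // position-dependent value that is added to the scores.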
current_slope + : -current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + AccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } + } + } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile:: + reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + ck_tile::reference_batched_softmax( + s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref); + + if(p_drop > 0) + { + p_hp_host_ref.ForEach( + [&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); }); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) { + p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + else + { + p_hp_host_ref.ForEach([&](auto& self, auto idx) { + p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + + // O = P * V + ck_tile::reference_batched_gemm( + p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n + + // clang-format off + // permute + if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); + else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); + + lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); }); + // clang-format on + + q_host_refs.push_back(q_host_ref); + k_host_refs.push_back(k_host_ref); + v_host_refs.push_back(v_host_ref); + o_host_refs.push_back(o_host_ref); + p_hp_host_refs.push_back(p_hp_host_ref); + p_lp_host_refs.push_back(p_lp_host_ref); + if(p_drop > 0) + { + randval_host_refs.push_back(randval_host_ref); + } + } + + o_buf.ToDevice(o_host.data()); + lse_buf.ToDevice(lse_host.data()); + dq_buf.SetZero(); + dbias_buf.SetZero(); + + ck_tile::stream_config stream_config_v{ + nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; + fmha_bwd(fmha_traits, fmha_args, stream_config_v); + + dq_buf.FromDevice(dq_host.data()); + dk_buf.FromDevice(dk_host.data()); + dv_buf.FromDevice(dv_host.data()); + dbias_buf.FromDevice(dbias_host.data()); + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] 
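+        // Per-batch reference backward pass, checked against the kernel outputs below:
+        //   dP  = dO @ V^T         (dropout mask re-applied when p_drop > 0)
+        //   D_i = dot(dO_i, O_i)   (row-wise over hdim_v)
+        //   dS  = P .* (dP - D)    (dBias = dS when the bias gradient is requested)
+        //   dV  = P^T @ dO,   dQ = scale * dS @ K,   dK = scale * dS^T @ Q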
- seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + + ck_tile::HostTensor do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o + ck_tile::HostTensor ds_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision + ck_tile::HostTensor ds_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision + ck_tile::HostTensor dp_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision + ck_tile::HostTensor dbias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + ck_tile::HostTensor dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + ck_tile::HostTensor dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + + // clang-format off + if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); }); + else do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, i[0], i[2]); }); + // clang-format on + + // dP = dO@V x Z w/ dropout + // dP = dO@V w/o dropout + auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o + ck_tile::reference_batched_gemm( + do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o + + if(p_drop > 0) + { + ck_tile::reference_batched_dropout( + dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop); + } + + // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) + ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + auto idx_gmo = idx_gmn; + idx_gmo[2] = o; + do_dot_o += ck_tile::type_convert(do_host_ref(idx_gmo)) * + ck_tile::type_convert(o_host_refs[wb](idx_gmo)); + } + self(idx_gmn) = ck_tile::type_convert( + p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o)); + }); + + if(use_dbias) + { + ds_hp_host_ref.ForEach([&](auto& self, auto idx) { + dbias_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + + ds_hp_host_ref.ForEach([&](auto& self, auto idx) { + ds_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + + // dV = P_drop^T@dO^T + // dV = P^T@dO^T w/o dropout + auto p_t_lp_host_ref = p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m + auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m + ck_tile::reference_batched_gemm( + p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m + + // dQ = scale * dS@K^T + auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n + ck_tile::reference_batched_gemm( + ds_lp_host_ref, + k_t_host_ref, + dq_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n + + // dK = scale * dS^T@Q^T + auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m + auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m + ck_tile::reference_batched_gemm( + ds_t_lp_host_ref, + q_t_host_ref, + dk_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m + + ck_tile::HostTensor dq_host_result( + {nhead, real_seqlen_q, hdim_q}); // 
dq_g_m_k + ck_tile::HostTensor dk_host_result( + {nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_result( + {nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + ck_tile::HostTensor dbias_host_result( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + + // clang-format off + // permute + if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + + if(i_perm) dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[0], idx[1] + key_offset, idx[2]); }); + else dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + + if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); }); + else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + + if(use_dbias) + { + if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + } + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool dq_cur_pass = ck_tile::check_err(dq_host_result, + dq_host_ref, + std::string("Error: QGrad Incorrect results!"), + rtol, + atol); + bool dk_cur_pass = ck_tile::check_err(dk_host_result, + dk_host_ref, + std::string("Error: KGrad Incorrect results!"), + rtol, + atol); + bool dv_cur_pass = ck_tile::check_err(dv_host_result, + dv_host_ref, + std::string("Error: VGrad Incorrect results!"), + rtol, + atol); + + bool dbias_cur_pass = true; + if(use_dbias) + { + dbias_cur_pass = ck_tile::check_err(dbias_host_result, + dbias_host_ref, + std::string("Error: BiasGrad Incorrect results!"), + rtol, + atol); + } + pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass); + if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass)) + { + std::cerr << "mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp new file mode 100644 index 000000000..0c6b46895 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -0,0 +1,359 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
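+// Host-side plumbing for the FMHA backward example:
+//  * FmhaBwdTypeConfig specializations: per-precision (fp16/bf16) tensor data types,
+//  * fmha_bwd_args: the runtime argument struct filled in by fmha_bwd.cpp,
+//  * fmha_bwd_dq_dk_dv_* / fmha_bwd_dot_do_o_* traits and launch helpers that the
+//    script-generated instantiations pattern-match against, and
+//  * fmha_bwd(): the public dispatch entry, returning the measured kernel time in ms.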
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "mask.hpp" +#include "bias.hpp" +#include + +template +struct FmhaBwdTypeConfig; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using GemmDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::half_t; + using OGradDataType = ck_tile::half_t; + using QGradDataType = ck_tile::half_t; + using KGradDataType = ck_tile::half_t; + using VGradDataType = ck_tile::half_t; + using BiasGradDataType = ck_tile::half_t; +}; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using GemmDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::bf16_t; + using OGradDataType = ck_tile::bf16_t; + using QGradDataType = ck_tile::bf16_t; + using KGradDataType = ck_tile::bf16_t; + using VGradDataType = ck_tile::bf16_t; + using BiasGradDataType = ck_tile::bf16_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_bwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + const void* o_ptr; + const void* lse_ptr; + const void* do_ptr; + void* d_ptr; + void* rand_val_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + void* dbias_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t max_seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + float scale; + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + ck_tile::index_t stride_dbias; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + ck_tile::index_t nhead_stride_dbias; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_lsed; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + 
ck_tile::index_t batch_stride_dbias; + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + float p_drop; + float p_undrop; + bool s_randval; + std::tuple drop_seed_offset; +}; + +template +auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) + { + return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dq_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dbias, + args.batch_stride_lsed, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dq_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dbias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_do, + args.batch_stride_lsed, + args.batch_stride_dk, + args.batch_stride_dv, + args.batch_stride_dbias, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode) + { + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqstart_q_ptr, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_lsed); + } + else + { // create batch mode kernel arguments + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqlen_q, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_do, + args.batch_stride_o, + args.batch_stride_lsed); + } + }(); + + dim3 grids = 
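+    // The dot_do_o pre-kernel only produces the row-wise dot(dO, O) term consumed by the
+    // main dq_dk_dv kernel (p_undrop is passed along with it), so its grid is presumably
+    // laid out over (batch, nhead_q, max_seqlen_q), whereas the main kernel above is
+    // sized over max_seqlen_k instead.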
FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_bwd_dq_dk_dv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dq_dk_dv_get_name_(); + +template +struct fmha_bwd_dot_do_o_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dot_do_o_get_name_(); + +// This is the public API, will be generated by script +struct fmha_bwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_dbias; + bool has_dropout; + // TODO: padding check is inside this api +}; +float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 91fc07d83..5f887f065 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_fwd.hpp" #include "ck_tile/host.hpp" @@ -110,6 +110,9 @@ auto create_args(int argc, char* argv[]) "11939", "random seed used for initializing input tensors. 
0 for " "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for random number generator") + .insert("drop_offset", "0", "offset for random number generator") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -128,26 +131,11 @@ auto get_elimit(std::string /*init_method*/) } template <> -auto get_elimit(std::string init_method) +auto get_elimit(std::string /*init_method*/) { - if(init_method == "ui" || init_method == "ni") - { - double rtol = 1e-2; - double atol = 1e-2; - return ck_tile::make_tuple(rtol, atol); - } - else if(init_method == "nf") - { - double rtol = 1e-2; - double atol = 1e-2; - return ck_tile::make_tuple(rtol, atol); - } - else - { - double rtol = 3e-3; - double atol = 3e-3; - return ck_tile::make_tuple(rtol, atol); - } + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); } template <> @@ -250,6 +238,21 @@ bool run(const ck_tile::ArgParser& arg_parser) mask_info mask = mask_info::decode( arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return false; + } + + bool s_randval = false; + if(p_drop > 0.0f && do_validation) + { + s_randval = true; + } + std::string init_method = arg_parser.get_str("init"); std::optional seed = arg_parser.get_uint32("seed"); if(*seed == 0) @@ -274,21 +277,23 @@ bool run(const ck_tile::ArgParser& arg_parser) using TypeConfig = FmhaFwdTypeConfig; - using QDataType = typename TypeConfig::QDataType; - using KDataType = typename TypeConfig::KDataType; - using VDataType = typename TypeConfig::VDataType; - using BiasDataType = typename TypeConfig::BiasDataType; - using LSEDataType = typename TypeConfig::LSEDataType; - using SaccDataType = typename TypeConfig::SaccDataType; - using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; - using PDataType = typename TypeConfig::PDataType; - using OaccDataType = typename TypeConfig::OaccDataType; - using ODataType = typename TypeConfig::ODataType; + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -300,6 +305,11 @@ bool run(const ck_tile::ArgParser& arg_parser) max_seqlen_q = real_seqlen_q; } + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + flop += nhead 
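            // Forward FLOP model: one Q@K^T gemm and one P@V gemm per head, each counted
            // as 2 flops per multiply-accumulate, hence the two 2*seqlen_q*seqlen_k*hdim
            // terms below.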
* (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); @@ -353,12 +363,16 @@ bool run(const ck_tile::ArgParser& arg_parser) // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q} + lse ? std::array{batch, nhead, max_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor randval_host( + p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + if(init_method == "ui" || init_method == "0") { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); @@ -434,6 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t)); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); @@ -463,8 +478,8 @@ bool run(const ck_tile::ArgParser& arg_parser) << (seqlen_kpads[0] < 0 ? "" : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias - << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout - << std::flush; + << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant + << ", mask:" << mask << ", v:" << vlayout << std::flush; auto fmha_traits = fmha_fwd_traits{hdim_q, hdim_v, @@ -474,6 +489,7 @@ bool run(const ck_tile::ArgParser& arg_parser) mask.type, bias.type, lse, + p_drop > 0.0f, squant}; auto p_compute_element_func = [&]() { @@ -505,8 +521,9 @@ bool run(const ck_tile::ArgParser& arg_parser) else return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; }(); - const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); @@ -518,21 +535,24 @@ bool run(const ck_tile::ArgParser& arg_parser) }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); - const ck_tile::index_t nhead_stride_lse = (shape_seqlen_q * 1); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = max_seqlen_q; + const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); - const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); - const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1); - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); return fmha_fwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() : bias_buf.GetDeviceBuffer(), + randval_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), @@ -554,22 +574,28 @@ bool run(const ck_tile::ArgParser& arg_parser) stride_v, bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias, + stride_randval, stride_o, nhead_stride_q, nhead_stride_k, nhead_stride_v, nhead_stride_bias, + nhead_stride_randval, nhead_stride_lse, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_bias, + batch_stride_randval, batch_stride_lse, batch_stride_o, mask.left, mask.right, - static_cast(mask.type)}; + static_cast(mask.type), + p_drop, + s_randval, + {drop_seed, drop_offset}}; }(); float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); @@ -596,6 +622,11 @@ bool run(const ck_tile::ArgParser& arg_parser) o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); + randval_buf.FromDevice(randval_host.data()); + float p_undrop = 1.0 - p_drop; + uint8_t p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; bool pass = true; @@ -771,6 +802,17 @@ bool run(const ck_tile::ArgParser& arg_parser) s_host_ref, p_host_ref, p_compute_element_func); } + if(p_drop > 0) + { + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + } + ck_tile::reference_batched_gemm( p_host_ref, v_host_ref, @@ -804,9 +846,8 @@ bool run(const ck_tile::ArgParser& arg_parser) if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b, idx[0], idx[1] + query_offset); - }); + lse_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); bool lse_pass = ck_tile::check_err(lse_host_result, lse_host_ref, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index fb3907fec..3594f61db 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ 
b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -17,61 +17,65 @@ struct FmhaFwdTypeConfig; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::half_t; - using KDataType = ck_tile::half_t; - using VDataType = ck_tile::half_t; - using BiasDataType = ck_tile::half_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::half_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::half_t; + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::bf16_t; - using KDataType = ck_tile::bf16_t; - using VDataType = ck_tile::bf16_t; - using BiasDataType = ck_tile::bf16_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf16_t; + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::fp8_t; - using KDataType = ck_tile::fp8_t; - using VDataType = ck_tile::fp8_t; - using BiasDataType = float; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::fp8_t; + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // 
data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::bf8_t; - using KDataType = ck_tile::bf8_t; - using VDataType = ck_tile::bf8_t; - using BiasDataType = ck_tile::bf8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; }; struct FmhaMasks @@ -88,6 +92,7 @@ struct fmha_fwd_args const void* k_ptr; const void* v_ptr; const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -108,22 +113,28 @@ struct fmha_fwd_args ck_tile::index_t stride_k; ck_tile::index_t stride_v; ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; + float p_drop; + bool s_randval; + std::tuple drop_seed_offset; }; template @@ -138,6 +149,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.rand_val_ptr, args.lse_ptr, args.o_ptr, args.seqstart_q_ptr, @@ -145,6 +157,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqlen_k_ptr, args.hdim_q, args.hdim_v, + args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, args.scale_p, @@ -153,16 +166,22 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.stride_k, args.stride_v, args.stride_bias, + args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, + args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.batch_stride_lse, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } else { // create batch mode kernel arguments @@ -170,12 +189,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.rand_val_ptr, 
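+                                                 // Dropout additions: the randval dump
+                                                 // pointer above plus its strides and the
+                                                 // (p_drop, s_randval, drop_seed_offset)
+                                                 // values further down are threaded
+                                                 // straight through to MakeKargs.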
args.lse_ptr, args.o_ptr, args.seqlen_q, args.seqlen_k, args.hdim_q, args.hdim_v, + args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, args.scale_p, @@ -184,22 +205,28 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.stride_k, args.stride_v, args.stride_bias, + args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, + args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, args.batch_stride_bias, + args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } }(); @@ -222,6 +249,7 @@ template ; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kPadS = kPadS_; static constexpr bool kPadSK = kPadSK_; @@ -264,6 +293,7 @@ struct fmha_fwd_traits mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; + bool has_dropout; bool do_fp8_static_quant; // TODO: padding check is inside this api }; diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index f0180d6db..e0b4b6559 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -83,7 +83,6 @@ TILE_PARTITIONER_MAP = { "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", } -DIRECTIONS = ["fwd"] GEN_DIR = "" # in Cmake, have to generate files in same folder FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT @@ -111,8 +110,10 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_dpad}, {F_dvpad}, {F_bias}, + false, {F_lse}, - {F_squant}, + {F_dropout}, + {F_squant}, {F_occupancy}>; using fmha_mask_{F_idx} = {F_mask}; @@ -123,6 +124,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::SaccDataType, typename FmhaFwdTypeConfig::SMPLComputeDataType, typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, typename FmhaFwdTypeConfig::LSEDataType, typename FmhaFwdTypeConfig::PDataType, typename FmhaFwdTypeConfig::OaccDataType, @@ -146,7 +148,7 @@ using fmha_kernel_{F_idx} = fmha_epilogue_{F_idx}>; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -191,9 +193,9 @@ MASK_SIMPLIFIED_CHECK_MAP = { "s_mask" : "t.mask_type != mask_enum::no_mask", } -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && 
({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); }} """ @@ -233,6 +235,7 @@ class FmhaFwdApiTrait: mask : str bias : str # lse : str # + dropout : str squant : str # spad : str skpad : str @@ -242,7 +245,7 @@ class FmhaFwdApiTrait: @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' @property def scheck(self) -> str: @@ -299,6 +302,7 @@ class FmhaFwdPipeline: F_dvpad : str # F_bias : str # true/false F_lse : str # + F_dropout : str # F_squant : str # F_mask : str # value from MASK_MAP @@ -321,6 +325,7 @@ class FmhaFwdPipeline: else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' if self.F_squant == 't' : n += '_squant' return n @@ -351,7 +356,7 @@ class FmhaFwdApiPool: inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, @@ -365,7 +370,7 @@ class FmhaFwdApiPool: @dataclass class FmhaFwdTileSize: F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along qk seqlen + F_bn0 : int # tile size along k seqlen F_bk0 : int # tile size along qk gemm unroll F_bn1 : int # tile size along v head_dim F_bk1 : int # tile size along kv gemm unroll @@ -424,9 +429,10 @@ class FmhaFwdKernel: F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_occupancy = self.F_tile.F_occupancy, F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], @@ -461,6 +467,7 @@ class FmhaFwdKernel: mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, squant=self.F_pipeline.F_squant, spad=self.F_pipeline.F_spad, 
skpad=self.F_pipeline.F_skpad, @@ -489,7 +496,7 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ else: return None -def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: +def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: @@ -500,26 +507,26 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): if hdim == 256: # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) if receipt == 1: - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: - # no need lse kernels + # no need lse/dropout kernels for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) else: assert False return pipelines @@ 
-527,7 +534,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw gen = list() api_pool = FmhaFwdApiPool(mask_impl) - for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()): + for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()): d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) if d == None: continue @@ -551,44 +558,660 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) return (api_pool, gen) -def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: +BWD_DQDKDV_PIPELINE_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR", +} + +BWD_DQDKDV_PIPELINE_ENUM_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR", +} + +FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_bwd.hpp" +""" + +FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; +using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; +using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; +using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + {F_dbias}, + false, + {F_dropout}, + false, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_bwd_trait_{F_idx}>; + +using fmha_bwd_pipeline_{F_idx} = {F_pipeline}< + fmha_bwd_pipeline_problem_{F_idx}>; + +using fmha_bwd_dk_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, + false, 
false>>; + +using fmha_bwd_dv_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, + false, false>>; + +using fmha_bwd_dq_dk_dv_kernel_{F_idx} = + ck_tile::FmhaBwdDQDKDVKernel, + fmha_bwd_pipeline_{F_idx}, + fmha_bwd_dk_epilogue_{F_idx}, + fmha_bwd_dv_epilogue_{F_idx}>; + +using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +std::string fmha_bwd_dq_dk_dv_get_name_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" +FMHA_BWD_API=""" +#include + +template +float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} + ); +}} + +float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; + r = fmha_bwd_(s, a); + return r; + }} +""" + +@dataclass +class FmhaBwdDQDKDVApiTrait: + pipeline : str + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along k seqlen + bhdq : int # q head_dim + bhdv : int # v head_dim + mask : str + bias : str + dbias : str + dropout : str + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return 
f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + def scheck(self, spad1 : str) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad == 't' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} != 0' + elif self.spad == 'f' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize + else: # self.skpad == 'f' and skpad1 == 'f' + return f'a.seqlen_q % 256 == 0' # BlockSize + + @property + def skcheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.skpad == 't': + return f'a.seqlen_k % {self.bn0} != 0' + else: + return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' + else : return f'a.hdim_q % {self.bhdq} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' + else : return f'a.hdim_v % {self.bhdv} == 0' + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = dict() + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.dq_dk_dv_pool.keys(): + self.dq_dk_dv_pool[trait.dtype] = dict() + if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): + self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() + + self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): + traits=self.dq_dk_dv_pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + for spad1 in ["t", "f"]: + if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): + continue + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], + F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) + + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + +# GEMM0: Q@K=S^T +# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) +# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) +# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) +# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) +# Is it necessary to distinguish between K0~K4? 
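As a brief, illustrative aside on the GEMM naming in the comment block above: the sketch below shows how those five backward GEMMs line up with the F_bk0..F_bk4 unroll fields of the FmhaBwdDQDKDVTileSize dataclass that follows. The numbers mirror the hdim-64 fp16/bf16 entry registered later in get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype; the dict form and its name are purely for illustration and are not used by the generator itself.

# Illustration only (not part of the patch): which dimension each backward GEMM
# contracts over, and the tile field that unrolls it (per the dataclass comments below).
#   GEMM0  Q@K = S^T       contracts over hdim_q        -> unrolled by F_bk0
#   GEMM1  P^T@dO^T = dV   contracts over the q-seqlen  -> unrolled by F_bk1
#   GEMM2  dO@V = dP^T     contracts over hdim_v        -> unrolled by F_bk2
#   GEMM3  dS^T@Q^T = dK   contracts over the q-seqlen  -> unrolled by F_bk3
#   GEMM4  dS@K^T = dQ     contracts over the k-seqlen  -> unrolled by F_bk4
example_hdim64_bwd_tile = dict(
    bm0=64, bn0=128,            # block tile: q-seqlen x k-seqlen
    bk0=32, bk1=32, bk2=32,     # unroll sizes for GEMM0 / GEMM1 / GEMM2
    bk3=32, bk4=32,             # unroll sizes for GEMM3 / GEMM4
    bhdq=64, bhdv=64)           # q/v head dims covered per block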
+@dataclass +class FmhaBwdDQDKDVTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along gemm0 unroll(F_bhdq) + F_bk1 : int # tile size along gemm1 unroll(F_bm0) + F_bk2 : int # tile size along gemm2 unroll(F_bhdv) + F_bk3 : int # tile size along gemm3 unroll(F_bm0) + F_bk4 : int # tile size along gemm4 unroll(F_bn0) + F_bhdq : int # q head_dim + F_bhdv : int # v head_dim + F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 + F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 + F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2 + F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 + F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3 + F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3 + F_rm2 : int # number of warps along k seqlen (block warps) in gemm4 + F_rn2 : int # number of warps along q seqlen (block warps) in gemm4 + F_rk2 : int # number of warps along gemm-k (not used) in gemm4 + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ + f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" + +@dataclass +class FmhaBwdDQDKDVKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaBwdDQDKDVTileSize + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # + F_dbias : str # + F_dropout : str # + F_mask : str # value from MASK_MAP + F_mode : str # value from MODE_MAP + F_pipeline : str + mask_impl : str + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bk1 = self.F_tile.F_bk1, + F_bk2 = self.F_tile.F_bk2, + F_bk3 = self.F_tile.F_bk3, + F_bk4 = self.F_tile.F_bk4, + F_bhdq = self.F_tile.F_bhdq, + F_bhdv = self.F_tile.F_bhdv, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_rm2 = self.F_tile.F_rm2, + F_rn2 = self.F_tile.F_rn2, + F_rk2 = self.F_tile.F_rk2, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_spad = BOOL_MAP[self.F_spad], + F_skpad = BOOL_MAP[self.F_skpad], + F_dpad = BOOL_MAP[self.F_dpad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_bias = BIAS_MAP[self.F_bias], + F_dbias = BOOL_MAP[self.F_dbias], + F_dropout = BOOL_MAP[self.F_dropout], + F_occupancy = self.F_tile.F_occupancy, + F_mask = get_mask_map(self.mask_impl)[self.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], + F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad 
== 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_dbias == 't' : n += '_dbias' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_dropout == 't' : n += '_dropout' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaBwdDQDKDVApiTrait: + return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bhdq=self.F_tile.F_bhdq, + bhdv=self.F_tile.F_bhdv, + mask=self.F_mask, + bias=self.F_bias, + dbias=self.F_dbias, + dropout=self.F_dropout, + spad=self.F_spad, + skpad=self.F_skpad, + dpad=self.F_dpad, + dvpad=self.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size & pipeline. +def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: + if direction == 'bwd': + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1), + "qs_ks_vr_dos"], + '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "qs_ks_vr_dos"], + '128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "ks_vr"] + } + else: + return None + else: + return None + +def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: + # TODO: we don't support tuning yet, so pick up one value for pad + # support this in future + gen = list() + api_pool = FmhaBwdApiPool(mask_impl) + + for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) + if d == None: + continue + for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): + tile = d[hdim_str][0] + ppl = d[hdim_str][1] + hdim = int(hdim_str) + if (mode == "group") and (spad == "f" or skpad == "f"): + continue + if ((bias == "no" or bias == "alibi") and dbias == "t"): + continue + k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, + F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, + F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, + F_pipeline=ppl, mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + if not cond: + continue + api_pool.register_dq_dk_dv_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, + {F_dvpad}, + {F_occupancy}>; + +using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename 
FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 256, + {F_hdim}, + {F_mode}, + fmha_bwd_dot_do_o_trait_{F_idx}>; + +using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO< + fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; + +using fmha_bwd_dot_do_o_kernel_{F_idx} = + ck_tile::FmhaBwdOGradDotOKernel, + fmha_bwd_dot_do_o_{F_idx}>; + +using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; + +#include + +template<> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +std::string fmha_bwd_dot_do_o_get_name_() +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +@dataclass +class FmhaBwdOGradDotOKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_spad : str # true/false + F_dvpad : str # + F_mode : str # value from MODE_MAP + F_occupancy : int + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_spad = BOOL_MAP[self.F_spad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_mode = MODE_MAP[self.F_mode], + F_occupancy = self.F_occupancy) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" + if pn != '' : n += f'_{pn}' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + +def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + gen = list() + + for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) + if d == None: + continue + for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): + hdim = int(hdim_str) + if (mode == "group" and spad == "f"): + continue + k = FmhaBwdOGradDotOKernel(direction=direction+"_dot_do_o", F_idx=0, F_hdim=hdim, F_dtype=dtype, + F_spad=spad, F_dvpad=dvpad, F_mode=mode, + F_occupancy=get_occupancy(dtype, hdim)) + gen.append(k) + + return gen + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: (autogen_dir / 
kernel.filename).write_text(kernel.template) -def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: output_dir = Path(output_dir) / GEN_DIR output_dir.mkdir(parents=True, exist_ok=True) - api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - write_single_kernel(kernel, output_dir) - write_api(api_pool, output_dir) + if direction == 'fwd': + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + else: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + write_single_bwd_dot_do_o_kernel(kernel, output_dir) + api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) + write_bwd_api(api_pool, output_dir) # list all the files that will be generated -def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) with file_path.open('a') as f: - _, kernels = get_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + if direction == 'fwd': + _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + else: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate", description="gen api for CK fmha kernel", ) + parser.add_argument( + "-d", + "--direction", + default='fwd', + choices=['fwd', 'bwd'], + required=False, + help="choose the direction of kernels(default: fwd)" + ) parser.add_argument( "-o", "--output_dir", @@ -623,11 +1246,12 @@ if __name__ == "__main__": default=0, required=False, help="codegen receipt. 
0: generate only 8xhdim coverage\n" + \ - " 1: generate more instance to cover all hdim" + " 1: generate more instance to cover all hdim\n" + \ + " 2: Only generate instance for Flash attention integration" ) args = parser.parse_args() if args.list_blobs is not None: - list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask) + list_blobs(args.list_blobs, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask) + write_blobs(args.output_dir, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) diff --git a/example/ck_tile/01_fmha/script/benchmark_bwd.sh b/example/ck_tile/01_fmha/script/benchmark_bwd.sh new file mode 100644 index 000000000..7591f5442 --- /dev/null +++ b/example/ck_tile/01_fmha/script/benchmark_bwd.sh @@ -0,0 +1,21 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_fmha_bwd +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 + +done +done +done diff --git a/example/ck_tile/01_fmha/script/benchmark.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh similarity index 100% rename from example/ck_tile/01_fmha/script/benchmark.sh rename to example/ck_tile/01_fmha/script/benchmark_fwd.sh diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh new file mode 100644 index 000000000..9fe795471 --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -0,0 +1,33 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_fmha_bwd +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1' + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 ; do +for mode in 0 1 ; do +for bias in "n" "e" "a"; do +for dbias in 0 1 ; do +for p_drop in 0.0 0.2; do + +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias 
-dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS + +done +done +done +done +done +done +done diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh similarity index 57% rename from example/ck_tile/01_fmha/script/smoke_test.sh rename to example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 21f679e11..63813e079 100755 --- a/example/ck_tile/01_fmha/script/smoke_test.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -17,17 +17,18 @@ for perm in 0 1 ; do for vlayout in "r" "c" ; do for hdim in 32 64 128 256 ; do for lse in 0 1 ; do -for bias in "n" "e" "a"; do +for bias in "n" "e" "a" ; do +for p_drop in 0.0 0.2; do -# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME 
$COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS done @@ -37,6 +38,7 @@ done done done done +done for perm in 0 1 ; do for bias in "n" "e" "a" ; do diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index bb19c9154..bb490cce4 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/utility.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" @@ -47,10 +48,12 @@ #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index ac2f0cab9..9c6e85f01 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -783,6 +783,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); +// buffer store ui16 +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, int32x4_t rsrc, @@ -1353,7 +1375,10 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); if constexpr(std::is_same::value) // fp32 @@ -1492,6 +1517,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d static_cast(coherence)); } } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_ui16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_ui16x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_ui16x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(uint16_t), + static_cast(coherence)); + } + } else { using r_t = thread_buffer; @@ -1609,7 +1677,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer& src_th { if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp new file mode 100644 index 000000000..6212db916 --- /dev/null +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +#pragma once +#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/numeric/type_convert.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" + +namespace ck_tile { + +CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b) +{ + return type_convert(type_convert(a) + type_convert(b)); +} + +CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b) +{ + bf16x2_t rtn; + rtn[0] = add_bf16_t(a[0], b[0]); + rtn[1] = add_bf16_t(a[1], b[1]); + return rtn; +} + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_add explicit for +// each datatype. +template +CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x); + +template <> +CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) +{ + union U32BF162_ADDR + { + uint32_t* u32_a; + bf16x2_t* bf162_a; + }; + + union U32BF162 + { + uint32_t u32; + bf16x2_t bf162; + }; + + U32BF162_ADDR dword_addr; + U32BF162 cur_v; + U32BF162 new_; + uint32_t old_v, new_v; + dword_addr.bf162_a = p_dst; + cur_v.u32 = *dword_addr.u32_a; + + do + { + old_v = cur_v.u32; + new_.bf162 = add_bf16x2_t(cur_v.bf162, x); + new_v = new_.u32; + cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); + } while(cur_v.u32 != old_v); +} + +template +CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) +{ + static_assert((std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 2 || N == 4)), + "wrong! not implemented"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicAdd(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicAdd(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + return atomicAdd(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicAdd(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicAdd(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 2) + { + atomic_add(c_style_pointer_cast(p_dst), bit_cast(x)); + } + else if constexpr(N == 4) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomic_add(c_style_pointer_cast(p_dst) + 1, + x.template get_as()[I1]); + } + } +} + +template +CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer& x) +{ + static_assert((std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 1)), + "wrong! 
not implemented"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicMax(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicMax(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 85d9be1c9..c23c12f29 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -144,6 +144,15 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16))); using int8x32_t = int8_t __attribute((ext_vector_type(32))); using int8x64_t = int8_t __attribute((ext_vector_type(64))); +// ui8 +// using uint8_t +using uint8x2_t = uint8_t __attribute((ext_vector_type(2))); +using uint8x4_t = uint8_t __attribute((ext_vector_type(4))); +using uint8x8_t = uint8_t __attribute((ext_vector_type(8))); +using uint8x16_t = uint8_t __attribute((ext_vector_type(16))); +using uint8x32_t = uint8_t __attribute((ext_vector_type(32))); +using uint8x64_t = uint8_t __attribute((ext_vector_type(64))); + #if CK_TILE_USE_CUSTOM_DATA_TYPE // f8 // using fp8_t diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 96b38241c..ffe8f7a4f 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" +#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" @@ -507,10 +508,10 @@ struct buffer_view, t_per_x>( x, p_data_, i, is_valid_element, buffer_size_); } @@ -518,7 +519,7 @@ struct buffer_view(c_style_pointer_cast(&p_data_[i]), x); + atomic_add_g, t_per_x>(&p_data_[i], x); } } } @@ -547,16 +548,16 @@ struct buffer_view, t_per_x>( x, p_data_, i, is_valid_element, buffer_size_); } else if(is_valid_element) { - atomic_max(c_style_pointer_cast(&p_data_[i]), x); + atomic_max_g, t_per_x>(&p_data_[i], x); } } diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index c12ad883d..2efc65701 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index e37bd806d..656309532 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -16,7 +16,9 @@ namespace ck_tile { -template +template struct tensor_view { using buffer_view = remove_reference_t; @@ -24,6 +26,7 @@ struct tensor_view using TensorDesc = remove_cvref_t; using TensorIndex = array; using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{})); + static constexpr auto DstInMemOp = DstInMemOp_; CK_TILE_HOST_DEVICE constexpr tensor_view() = default; @@ -140,6 +143,23 @@ struct tensor_view x); } + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements( + const TensorCoord& coord, const X& x, bool_constant = {}) + { + buf_.template update( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } + CK_TILE_HOST_DEVICE void print() const { printf("tensor_view{"); @@ -178,6 +198,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, } template (p, desc.get_element_space_size()); - return tensor_view{buffer_view, desc}; + return tensor_view{buffer_view, desc}; } template >{ - old_tensor_view.buf_, new_desc}; + return tensor_view, + remove_cvref_t::DstInMemOp>{old_tensor_view.buf_, new_desc}; } template + CK_TILE_DEVICE void update(const static_distributed_tensor& dstr_tensor, + bool_constant = {}) const + { + using Traits = load_store_traits; + + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from distributed tensor + vector_t vec_value; + + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_array( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + vec_value.template get_as()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // write into bottom tensor + get_bottom_tensor_view().template update_vectorized_elements( + bottom_tensor_thread_coord, vec_value, bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + // move thread's botom tensor coordiante // [x0', x1', ... 
] ==> [offset] // also move window-origin diff --git a/include/ck_tile/core/tensor/update_tile.hpp b/include/ck_tile/core/tensor/update_tile.hpp new file mode 100644 index 000000000..fbce7c408 --- /dev/null +++ b/include/ck_tile/core/tensor/update_tile.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +CK_TILE_DEVICE void +update_tile(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr); + + tile_window.update(dstr_tensor); +} + +template +CK_TILE_DEVICE void +update_tile(tile_window_with_static_distribution& tile_window, + const static_distributed_tensor& dstr_tensor) +{ + tile_window.update(dstr_tensor); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/philox_rand.hpp b/include/ck_tile/core/utility/philox_rand.hpp new file mode 100644 index 000000000..c49f44ae4 --- /dev/null +++ b/include/ck_tile/core/utility/philox_rand.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh +class philox +{ + public: + CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_) + : seed(reinterpret_cast(seed_)) + { + + ull2* tmp = reinterpret_cast(&counter); + tmp->x = offset_; + } + + CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const + { + + uint4 counter_ = counter; + ull2* tmp = reinterpret_cast(&counter_); + tmp->y = subsequence; + + uint2 key_ = seed; +// 7-round philox +#pragma unroll + for(int i = 0; i < 6; i++) + { + counter_ = philox_single_round(counter_, key_); + key_.x += kPhilox10A; + key_.y += kPhilox10B; + } + uint4 output = philox_single_round(counter_, key_); + return output; + } + + CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out, + const unsigned long long subsequence) const + { + uint4 tmp_ph; + tmp_ph = get_philox_4x32(subsequence); + + uint32_t* out_tmp = reinterpret_cast(&out[0]); + + out_tmp[0] = tmp_ph.x; + out_tmp[1] = tmp_ph.y; + out_tmp[2] = tmp_ph.z; + out_tmp[3] = tmp_ph.w; + } + + private: + struct ull2 + { + uint64_t x; + uint64_t y; + }; + uint4 counter; + const uint2 seed; + + CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const + { + uint2* res; + unsigned long long tmp; + tmp = static_cast(a) * b; + res = reinterpret_cast(&tmp); + return *res; + } + + CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const + { + + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; + } + + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + static const unsigned long kPhiloxSA = 0xD2511F53; + static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 98a3bb7d7..09030fa6d 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -11,6 +11,7 @@ #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/ranges.hpp" +#include "ck_tile/host/reference/reference_batched_dropout.hpp" #include "ck_tile/host/reference/reference_batched_elementwise.hpp" #include "ck_tile/host/reference/reference_batched_gemm.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp" diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index cd0dc3825..43405ee69 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -156,7 +156,7 @@ struct HostTensorDescriptor } const std::vector& get_lengths() const { return mLens; } - const std::vector& GetStrides() const { return mStrides; } + const std::vector& get_strides() const { return mStrides; } template std::size_t GetOffsetFromMultiIndex(Is... 
is) const @@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old for(std::size_t i = 0; i < a.get_num_of_dimension(); i++) { new_lengths[i] = a.get_lengths()[new2old[i]]; - new_strides[i] = a.GetStrides()[new2old[i]]; + new_strides[i] = a.get_strides()[new2old[i]]; } return HostTensorDescriptor(new_lengths, new_strides); @@ -327,7 +327,7 @@ struct HostTensor decltype(auto) get_lengths() const { return mDesc.get_lengths(); } - decltype(auto) GetStrides() const { return mDesc.GetStrides(); } + decltype(auto) get_strides() const { return mDesc.get_strides(); } std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); } @@ -481,6 +481,34 @@ struct HostTensor return mData[mDesc.GetOffsetFromMultiIndex(idx)]; } + HostTensor transpose(std::vector axes = {}) const + { + if(axes.empty()) + { + axes.resize(this->get_num_of_dimension()); + std::iota(axes.rbegin(), axes.rend(), 0); + } + if(axes.size() != mDesc.get_num_of_dimension()) + { + throw std::runtime_error( + "HostTensor::transpose(): size of axes must match tensor dimension"); + } + std::vector tlengths, tstrides; + for(const auto& axis : axes) + { + tlengths.push_back(get_lengths()[axis]); + tstrides.push_back(get_strides()[axis]); + } + HostTensor ret(*this); + ret.mDesc = HostTensorDescriptor(tlengths, tstrides); + return ret; + } + + HostTensor transpose(std::vector axes = {}) + { + return const_cast const*>(this)->transpose(axes); + } + typename Data::iterator begin() { return mData.begin(); } typename Data::iterator end() { return mData.end(); } diff --git a/include/ck_tile/host/reference/reference_batched_dropout.hpp b/include/ck_tile/host/reference/reference_batched_dropout.hpp new file mode 100644 index 000000000..242101bf4 --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_dropout.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +CK_TILE_HOST void reference_batched_dropout(HostTensor& in_out_b_m_n, + const HostTensor& randval_b_m_n, + const uint8_t& p_undrop_in_uint8_t, + const float scale) +{ + const int N = in_out_b_m_n.mDesc.get_lengths()[2]; + auto f = [&](auto batch, auto m) { + for(int n = 0; n < N; ++n) + { + float tmp = ck_tile::type_convert(in_out_b_m_n(batch, m, n)) * scale; + in_out_b_m_n(batch, m, n) = randval_b_m_n(batch, m, n) <= p_undrop_in_uint8_t + ? 
ck_tile::type_convert(tmp) + : DataType(0); + } + }; + + make_ParallelTensorFunctor( + f, randval_b_m_n.mDesc.get_lengths()[0], randval_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 1122bf87b..568486830 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -4,10 +4,24 @@ #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/block/block_position_encoding.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp new file mode 100644 index 000000000..1f0fe2bd6 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -0,0 +1,329 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
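// A minimal scalar sketch of the dropout semantics shared by reference_batched_dropout
// above and the BlockDropout helper added below: an element survives when its uint8
// random value is <= p_undrop quantized to 8 bits, and survivors are rescaled by
// 1/p_undrop so the expectation is preserved. The 255 quantization constant is an
// assumption (the kernel argument setup appears to use
// std::numeric_limits<uint8_t>::max()); the helper name is hypothetical.
#include <cstdint>
#include <cmath>

inline float dropout_ref_scalar(float x, uint8_t randval, float p_drop)
{
    const float   p_undrop    = 1.0f - p_drop;
    const uint8_t p_undrop_u8 = static_cast<uint8_t>(std::floor(p_undrop * 255.0f));
    const float   rp_undrop   = 1.0f / p_undrop; // "rp_undrop" in the kernel args
    return randval <= p_undrop_u8 ? x * rp_undrop : 0.0f;
}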
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +struct BlockDropout +{ + CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, + index_t i_head, + index_t nheads, + unsigned long long seed, + unsigned long long offset, + float rp_undrop_, + uint8_t p_undrop_in_uint8_t_, + bool is_store_randval_) + : ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()), + rp_undrop(rp_undrop_), + p_undrop_in_uint8_t(p_undrop_in_uint8_t_), + is_store_randval(is_store_randval_) + { + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); + auto randval_dram_window = [&]() { + if constexpr(IsFwd) + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N + } + else + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N + } + }(); + + return randval_dram_window; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = WG::kN; + constexpr index_t kN1 = 8; + constexpr index_t kN0 = kNPerStep / kN1; + + constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor( + ck_tile::make_tuple(number{}, number{}, number{}), + ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto randval_lds_block_desc = transform_tensor_descriptor( + randval_lds_block_desc_0, + ck_tile::make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(ck_tile::make_tuple(number{}, number{}))), + ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}), + ck_tile::make_tuple(sequence<0>{}, sequence<1>{})); + + return randval_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. 
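// Note: the inner part of this distribution is deliberately borrowed from the backward
// WarpGemm so that the forward and backward kernels map identical (row, col) elements to
// identical lanes; if the two sides disagreed, the dropout mask regenerated during the
// backward pass would not match the one applied in the forward pass. Which swizzled MFMA
// variant gets selected appears to depend on whether the data types are all fp16 versus
// bf16 (the angle-bracketed arguments of the is_same_v checks are assumed here).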
+ constexpr auto randval_block_inner_part_dstr_encoding = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; + } + else + { + return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; + } + }(); + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + randval_block_inner_part_dstr_encoding); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + typename WG::CWarpDstrEncoding{}); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE void Run(void* randval_ptr, + const index_t start_n0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // randval tile in LDS + auto randval_lds = make_tensor_view( + reinterpret_cast(randval_ptr), MakeRandValLdsBlockDescriptor()); + + auto randval_lds_window = make_tile_window( + randval_lds, MakeRandValLdsBlockDescriptor().get_lengths(), {0, 0}); + + // register distribute + auto randval_dist_generated = + make_static_distributed_tensor(MakeRandValTileDistribution()); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + + auto randval_lds_read_window = + make_tile_window(randval_lds_window.get_bottom_tensor_view(), + randval_lds_window.get_window_lengths(), + randval_lds_window.get_window_origin(), + MakeRandValLdsShuffleTileDistribution()); + + const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); + int block_col_start = (start_n0_idx / WG::kN) + i_n0; + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto 
idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // save to LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + // read from LDS to register + auto randval = load_tile(randval_lds_read_window); + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto p_idx0 = tile_distributed_index{}; + constexpr auto p_idx1 = + tile_distributed_index{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] * rp_undrop + : PComputeDataType(0); + }); + }); + // save to Global + if(is_store_randval) + { + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {0, kNPerStep}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); + } + } + + template + CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // register distribute + auto randval = + make_static_distributed_tensor(MakeRandValTileDistribution()); + static_assert(randval.kThreadElementSpaceSize == 16); + + const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{}); + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + int block_row_start = (start_m0_idx / WG::kM) + i_m0; + int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id(); + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); + + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + randval(r_idx) = random_uint8_t[i_random_idx++]; + constexpr auto p_idx0 = + tile_distributed_index{}; + constexpr auto p_idx1 = tile_distributed_index{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? 
p_compute[p_idx] + : -p_compute[p_idx]; + }); + }); + // save to Global + if(is_store_randval) + { + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {kMPerStep, 0}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock}); + } + } + + ck_tile::philox ph; + const float rp_undrop; + const uint8_t p_undrop_in_uint8_t; + const bool is_store_randval; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 7fb1c19b5..f43de4573 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -141,6 +141,36 @@ struct GenericAttentionMask } } + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const { @@ -160,14 +190,14 @@ struct GenericAttentionMask } else { - return i_x >= x_end; + return i_x >= x_end || i_y >= y_total; } } } // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel - // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // Attention! 
assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() // can be used as a fast-path to decide if do per-pixel check or not template CK_TILE_HOST_DEVICE constexpr auto @@ -269,6 +299,36 @@ struct SimplifiedGenericAttentionMask } } + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const { @@ -283,13 +343,13 @@ struct SimplifiedGenericAttentionMask index_t x_start = -y + i_y + 1; // this could be negative, but it's fine index_t x_end = min(i_y + x, x_total); // need min in case x is padded - return i_x < x_start || i_x >= x_end; + return i_x < x_start || i_x >= x_end || i_y >= y_total; } } // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel - // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() // can be used as a fast-path to decide if do per-pixel check or not template CK_TILE_HOST_DEVICE constexpr auto @@ -361,6 +421,6 @@ make_generic_attention_mask_from_lr_window(index_t left_size, { auto r = make_generic_attention_mask_coordinates_from_lr_window( left_size, right_size, y_total, x_total, is_top_left); - return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total}; + return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total}; } } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp new file mode 100644 index 000000000..e713cefbd --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -0,0 +1,1421 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
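// A minimal host-side sketch of the tile range computed by GetTileRangeAlongY() above,
// written with plain ints instead of number<> tags; the helper name is hypothetical.
// For a mask described by left/right extents (y, x), the columns [i_x, i_x + XTile) only
// touch rows in [max(i_x + 1 - x, 0), min(i_x + XTile - 1 + y, y_total)), rounded down
// and up to YTile boundaries. As the TODO in the source notes, the end can still come
// out smaller than the start, and the caller is expected to check for that.
#include <algorithm>
#include <utility>

inline std::pair<int, int>
tile_range_along_y_host(int i_x, int XTile, int YTile, int y, int x, int y_total)
{
    const int y_start = (std::max(-x + i_x + 1, 0) / YTile) * YTile;                         // round down
    const int y_end   = ((std::min(i_x + XTile - 1 + y, y_total) + YTile - 1) / YTile) * YTile; // round up
    return {y_start, y_end};
}
// Example: i_x = 0, XTile = 128, YTile = 64, y = 1, x = 1, y_total = 512 yields [0, 128).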
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// dV[seqlen_k, hdim_v] = P^T[seqlen_k, seqlen_q] @ dO^T[hdim_v, seqlen_q] +// dP[seqlen_q, seqlen_k] = dO[seqlen_q, hdim_v] @ V[seqlen_k, hdim_v] +// D[seqlen_q] = rowsum(dO[seqlen_q, hdim_v] * O[seqlen_q, hdim_v]) +// dS''[seqlen_q, seqlen_k] = P[seqlen_q, seqlen_k] * (dP[seqlen_q, seqlen_k] - D[seqlen_q]) +// dBias[seqlen_q, seqlen_k] = dS'[seqlen_q, seqlen_k] = dS''[seqlen_q, seqlen_k] +// dK[seqlen_k, hdim_q] = dS'^T[seqlen_k, seqlen_q] @ Q^T[hdim_q, seqlen_q] * Scale[1] +// dQ[seqlen_q, hdim_q] = dS'[seqlen_q, seqlen_k] @ K^T[hdim_q, seqlen_k] * Scale[1] + +namespace ck_tile { + +template +struct FmhaBwdDQDKDVKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaPipeline = ck_tile::remove_cvref_t; + using KGradEpiloguePipeline = ck_tile::remove_cvref_t; + using VGradEpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using GemmDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using AccDataType = ck_tile::remove_cvref_t; + using DDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using OGradDataType = ck_tile::remove_cvref_t; + using QGradDataType = ck_tile::remove_cvref_t; + using KGradDataType = ck_tile::remove_cvref_t; + using VGradDataType = ck_tile::remove_cvref_t; + using BiasGradDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using gbr = typename bfs::Gemm0BlockWarps; + using gwt = typename bfs::Gemm0WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? 
n : std::string("p") + n; }(); + return + _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" + + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaBwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaBwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* lse_ptr; + const void* do_ptr; + const void* d_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t num_head_q; + ck_tile::index_t nhead_ratio_qk; + float raw_scale; +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale; +#endif + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + + ck_tile::index_t batch_stride_lsed; + }; + + struct FmhaBwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct FmhaBwdBatchModeBiasKargs : FmhaBwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct FmhaBwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct FmhaBwdCommonBiasGradKargs + { + void* dbias_ptr = nullptr; + ck_tile::index_t stride_dbias = 0; + ck_tile::index_t nhead_stride_dbias = 0; + }; + + struct FmhaBwdBatchModeBiasGradKargs : FmhaBwdCommonBiasGradKargs + { + ck_tile::index_t batch_stride_dbias = 0; + }; + + struct FmhaBwdMaskKargs + { + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaBwdCommonDropoutKargs + { + void init_dropout(const float p_drop, + const std::tuple& drop_seed_offset, + const float raw_scale) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + scale_rp_undrop = rp_undrop * raw_scale; + + drop_seed = 
std::get<0>(drop_seed_offset); + drop_offset = std::get<1>(drop_seed_offset); + } + float rp_undrop = 1; + float scale_rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + uint64_t drop_seed = 1; + uint64_t drop_offset = 0; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct FmhaBwdBatchModeKargs + : FmhaBwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + }; + + struct FmhaBwdGroupModeKargs + : FmhaBwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + const void* lse_ptr, + const void* do_ptr, + const void* d_ptr, + void* rand_val_ptr, + void* dq_ptr, + void* dk_ptr, + void* dv_ptr, + void* dbias_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_do, + ck_tile::index_t stride_dk, + ck_tile::index_t stride_dv, + ck_tile::index_t stride_dbias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dbias, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_do, + ck_tile::index_t batch_stride_lsed, + ck_tile::index_t batch_stride_dk, + ck_tile::index_t batch_stride_dv, + ck_tile::index_t batch_stride_dbias, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_ptr, + do_ptr, + d_ptr, + dq_ptr, + dk_ptr, + dv_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck_tile::log2e_v<>), +#endif + stride_q, + stride_k, + stride_v, + stride_do, + stride_dk, + stride_dv, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_do, + nhead_stride_lsed, + batch_stride_lsed}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_do, + batch_stride_dk, + batch_stride_dv}; + + if 
constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + + if constexpr(kHasBiasGrad) + { + kargs.dbias_ptr = dbias_ptr; + kargs.stride_dbias = stride_dbias; + kargs.nhead_stride_dbias = nhead_stride_dbias; + kargs.batch_stride_dbias = batch_stride_dbias; + } + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset, scale); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + const void* lse_ptr, + const void* do_ptr, + const void* d_ptr, + void* rand_val_ptr, + void* dq_ptr, + void* dk_ptr, + void* dv_ptr, + void* dbias_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_do, + ck_tile::index_t stride_dk, + ck_tile::index_t stride_dv, + ck_tile::index_t stride_dbias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dbias, + ck_tile::index_t batch_stride_lsed, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_ptr, + do_ptr, + d_ptr, + dq_ptr, + dk_ptr, + dv_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck_tile::log2e_v<>), +#endif + stride_q, + stride_k, + stride_v, + stride_do, + stride_dk, + stride_dv, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_do, + nhead_stride_lsed, + batch_stride_lsed}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasBiasGrad) + { + 
kargs.dbias_ptr = dbias_ptr; + kargs.stride_dbias = stride_dbias; + kargs.nhead_stride_dbias = nhead_stride_dbias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset, scale); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), + KGradEpiloguePipeline::GetSmemSize(), + VGradEpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k); + + const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_do = 0; + long_index_t batch_offset_lsed = 0; + long_index_t batch_offset_dk = 0; + long_index_t batch_offset_dv = 0; + long_index_t batch_offset_dbias = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; + batch_offset_do = query_start * kargs.stride_do; + batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; + batch_offset_dk = key_start * kargs.stride_dk; + batch_offset_dv = key_start * kargs.stride_dv; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + if constexpr(kHasBiasGrad) + { + batch_offset_dbias = query_start * kargs.stride_dbias; + } + else + { + batch_offset_dbias = key_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_k <= i_n0) + { + return; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_do = static_cast(i_batch) * 
kargs.batch_stride_do; + batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; + batch_offset_dk = static_cast(i_batch) * kargs.batch_stride_dk; + batch_offset_dv = static_cast(i_batch) * kargs.batch_stride_dv; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kHasBiasGrad) + { + batch_offset_dbias = static_cast(i_batch) * kargs.batch_stride_dbias; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + const LSEDataType* lse_ptr = reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lsed + + batch_offset_lsed; + const DDataType* d_ptr = reinterpret_cast(kargs.d_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lsed + + batch_offset_lsed; + const OGradDataType* do_ptr = reinterpret_cast(kargs.do_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_do + + batch_offset_do; + QGradDataType* dq_ptr = reinterpret_cast(kargs.dq_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + KGradDataType* dk_ptr = reinterpret_cast(kargs.dk_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_k + + batch_offset_dk; + VGradDataType* dv_ptr = reinterpret_cast(kargs.dv_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_v + + batch_offset_dv; + + // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + const auto q_dram = [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto qt_dram_naive = + transform_tensor_view(q_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_q), + make_pass_through_transform(kargs.seqlen_q)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto qt_dram = [&]() { + if constexpr(FmhaPipeline::kQTLoadOnce) + { + return pad_tensor_view( + qt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + qt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + const auto k_dram = [&]() { + if constexpr(FmhaPipeline::kKLoadOnce) + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto kt_dram_naive = + transform_tensor_view(k_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_q), + make_pass_through_transform(kargs.seqlen_k)), + 
make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto kt_dram = [&]() { + if constexpr(FmhaPipeline::kKTLoadOnce) + { + return pad_tensor_view( + kt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + kt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto v_dram = [&]() { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kVLoadOnce) + { + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view_packed( + lse_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + lse_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto d_dram = [&]() { + const auto d_dram_naive = make_naive_tensor_view_packed( + d_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + d_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto do_dram_naive = make_naive_tensor_view( + do_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_do, 1), + number{}, + number<1>{}); + const auto do_dram = [&]() { + if constexpr(FmhaPipeline::kOGradLoadOnce) + { + return pad_tensor_view( + do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto dot_dram_naive = + transform_tensor_view(do_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_q)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto dot_dram = [&]() { + if constexpr(FmhaPipeline::kOGradTLoadOnce) + { + return pad_tensor_view( + dot_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + dot_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto dq_dram = [&]() { + const auto dq_dram_naive = make_naive_tensor_view( + dq_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dq_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {0, 0}); + + auto qt_dram_window = + make_tile_window(qt_dram, + [&]() { + if constexpr(FmhaPipeline::kQTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, + [&]() { + if constexpr(FmhaPipeline::kKLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_n0, 0}); + + auto kt_dram_window = + make_tile_window(kt_dram, + [&]() { + if constexpr(FmhaPipeline::kKTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, i_n0}); + + auto v_dram_window = make_tile_window( + v_dram, + [&]() { + if 
constexpr(FmhaPipeline::kVLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_n0, 0}); + + auto do_dram_window = make_tile_window( + do_dram, + [&]() { + if constexpr(FmhaPipeline::kOGradLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {0, 0}); + + auto dot_dram_window = + make_tile_window(dot_dram, + [&]() { + if constexpr(FmhaPipeline::kOGradTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, 0}); + + auto dq_dram_window = make_tile_window( + dq_dram, + make_tuple(number{}, number{}), + {0, 0}); + + auto lse_dram_window = + make_tile_window(lse_dram, make_tuple(number{}), {0}); + + auto d_dram_window = make_tile_window(d_dram, make_tuple(number{}), {0}); + + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + auto dbias_dram_window = [&, i_nhead_ = i_nhead]() { + if constexpr(kHasBiasGrad) + { + BiasGradDataType* dbias_ptr = + reinterpret_cast(kargs.dbias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_dbias + + batch_offset_dbias; + + auto dbias_dram = [&]() { + const auto dbias_dram_naive = + make_naive_tensor_view( + dbias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_dbias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(dbias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? 
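// Note: one ALiBi slope is read per (batch, head) pair; alibi_slope_stride is the batch
// stride and, per the kargs comment above, is 0 when every batch shares the same slopes.
// Under CK_TILE_FMHA_FWD_FAST_EXP2 the slope is pre-scaled by log2(e) right after the
// load, presumably so it stays in the same log2 domain as the pre-scaled softmax scale
// used by the exp2-based pipeline.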
+ AccDataType slope = *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + // dropout + float rp_undrop = 1; + float scale_rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + uint64_t drop_seed = 0; + uint64_t drop_offset = 0; + bool is_store_randval = false; + + if constexpr(kHasDropout) + { + rp_undrop = kargs.rp_undrop; + scale_rp_undrop = kargs.scale_rp_undrop; + p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; + drop_seed = kargs.drop_seed; + drop_offset = kargs.drop_offset; + is_store_randval = kargs.is_store_randval; + } + BlockDropout dropout(i_batch, + i_nhead, + kargs.num_head_q, + drop_seed, + drop_offset, + rp_undrop, + p_undrop_in_uint8_t, + is_store_randval); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window, + qt_dram_window, + k_dram_window, + kt_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + do_dram_window, + dot_dram_window, + lse_dram_window, + d_dram_window, + dq_dram_window, + dbias_dram_window, + mask, + position_encoding, + kargs.raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + kargs.scale, +#endif + rp_undrop, + scale_rp_undrop, + smem_ptr, + dropout); + + auto dk_dram = [&]() { + const auto dk_dram_naive = make_naive_tensor_view( + dk_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_dk, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dk_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto dv_dram = [&]() { + const auto dv_dram_naive = make_naive_tensor_view( + dv_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_dv, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dv_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto dk_dram_window = make_tile_window( + dk_dram, + make_tuple(number{}, number{}), + {i_n0, 0}); + + auto dv_dram_window = make_tile_window( + dv_dram, + 
make_tuple(number{}, number{}), + {i_n0, 0}); + + KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile); + VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile); + } +}; + +template +struct FmhaBwdOGradDotOKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaBwdOGradDotO = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu; + static constexpr ck_tile::index_t kM0 = kBlockSize; + static constexpr ck_tile::index_t kVHeaddim = FmhaBwdOGradDotO::kVHeaddim; + + using DDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using OGradDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. 
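// A minimal host-side sketch of what the FmhaBwdOGradDotOKernel defined here produces,
// following the formula block at the top of this file: D[m] = rowsum(dO[m, :] * O[m, :])
// over hdim_v. How exactly p_undrop enters is an assumption; the kernel simply forwards
// it to the FmhaBwdOGradDotO pipeline. The helper below uses hypothetical names and
// plain row-major float buffers.
#include <cstddef>

inline void dot_do_o_ref_host(const float* o,   // [seqlen_q, hdim_v]
                              const float* d_o, // [seqlen_q, hdim_v]
                              float*       d,   // [seqlen_q]
                              std::size_t  seqlen_q,
                              std::size_t  hdim_v)
{
    for(std::size_t m = 0; m < seqlen_q; ++m)
    {
        float acc = 0.f;
        for(std::size_t n = 0; n < hdim_v; ++n)
            acc += d_o[m * hdim_v + n] * o[m * hdim_v + n];
        d[m] = acc;
    }
}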
+ struct FmhaBwdOGradDotOCommonKargs + { + const void* o_ptr; + const void* do_ptr; + void* d_ptr; + + float p_undrop; + + ck_tile::index_t seqlen_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t stride_do; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_d; + ck_tile::index_t batch_stride_d; + }; + + struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs + { + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs + { + const int32_t* seqstart_q_ptr; + }; + + using Kargs = std:: + conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* o_ptr, + const void* do_ptr, + void* d_ptr, + float p_undrop, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t stride_do, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_d, + ck_tile::index_t batch_stride_do, + ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_d) + { + Kargs kargs{{o_ptr, + do_ptr, + d_ptr, + p_undrop, + seqlen_q, + hdim_v, + stride_do, + stride_o, + nhead_stride_do, + nhead_stride_o, + nhead_stride_d, + batch_stride_d}, + batch_stride_do, + batch_stride_o}; + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* o_ptr, + const void* do_ptr, + void* d_ptr, + float p_undrop, + const void* seqstart_q_ptr, + ck_tile::index_t hdim_v, + ck_tile::index_t stride_do, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_d, + ck_tile::index_t batch_stride_d) + { + Kargs kargs{{o_ptr, + do_ptr, + d_ptr, + p_undrop, + -1, // seqlen will be updated by another pointer + hdim_v, + stride_do, + stride_o, + nhead_stride_do, + nhead_stride_o, + nhead_stride_d, + batch_stride_d}, + reinterpret_cast(seqstart_q_ptr)}; + + return kargs; + } + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // divide problem + const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); + + long_index_t batch_offset_o = 0; + long_index_t batch_offset_do = 0; + long_index_t batch_offset_d = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + + batch_offset_o = query_start * kargs.stride_o; + batch_offset_do = query_start * kargs.stride_do; + batch_offset_d = static_cast(i_batch) * kargs.batch_stride_d; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + } + else + { + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + batch_offset_do = 
static_cast(i_batch) * kargs.batch_stride_do; + batch_offset_d = static_cast(i_batch) * kargs.batch_stride_d; + } + + // for simplicity, batch stride we just modify the pointer + const ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + const OGradDataType* do_ptr = reinterpret_cast(kargs.do_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_do + + batch_offset_do; + DDataType* d_ptr = reinterpret_cast(kargs.d_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_d + + batch_offset_d; + + // O/dO/D DRAM and DRAM window + const auto o_dram = [&]() { + auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + return pad_tensor_view(o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto do_dram = [&]() { + auto do_dram_naive = make_naive_tensor_view( + do_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_do, 1), + number{}, + number<1>{}); + return pad_tensor_view(do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + auto d_dram = [&]() { + const auto d_dram_naive = make_naive_tensor_view_packed( + d_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + d_dram_naive, make_tuple(number{}), sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, make_tuple(number{}, number{}), {i_m0, 0}); + + auto do_dram_window = + make_tile_window(do_dram, make_tuple(number{}, number{}), {i_m0, 0}); + + auto d_dram_window = make_tile_window(d_dram, make_tuple(number{}), {i_m0}); + + FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp new file mode 100644 index 000000000..bc875b8e5 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaBwdTilePartitioner +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/) + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); + } +}; + +template +struct FmhaBwdOGradDotOTilePartitioner +{ + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/) + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 9992d56ea..6624bf1a9 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,11 +9,11 @@ #include #include -// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q] +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] -// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k]) -// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] namespace ck_tile { @@ -32,6 +32,8 @@ struct FmhaFwdKernel using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; using LSEDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; @@ -45,6 +47,7 @@ struct FmhaFwdKernel static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -84,7 +87,7 @@ struct FmhaFwdKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? 
_SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); #undef _SS_ #undef _TS_ // clang-format on @@ -111,6 +114,7 @@ struct FmhaFwdKernel ck_tile::index_t hdim_q; ck_tile::index_t hdim_v; + ck_tile::index_t num_head_q; // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k // if this param is larger than 1, indicate MQA/GQA case ck_tile::index_t nhead_ratio_qk; @@ -163,11 +167,35 @@ struct FmhaFwdKernel { void* lse_ptr = nullptr; ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; }; - struct FmhaFwdBatchModeLSEKargs : FmhaFwdCommonLSEKargs + struct FmhaFwdCommonDropoutKargs { - ck_tile::index_t batch_stride_lse = 0; + void init_dropout(const float p_drop, + const std::tuple& drop_seed_offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + drop_seed = std::get<0>(drop_seed_offset); + drop_offset = std::get<1>(drop_seed_offset); + } + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + uint64_t drop_seed = 1; + uint64_t drop_offset = 0; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; }; struct FmhaFwdBatchModeKargs @@ -178,8 +206,9 @@ struct FmhaFwdKernel FmhaFwdAlibiKargs, FmhaFwdEmptyKargs<0>>>, std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -196,7 +225,8 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -211,12 +241,14 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + void* rand_val_ptr, void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, @@ -225,22 +257,28 @@ struct FmhaFwdKernel ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -250,6 +288,7 @@ struct FmhaFwdKernel 
seqlen_k, hdim_q, hdim_v, + num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), @@ -268,6 +307,7 @@ struct FmhaFwdKernel {}, // placeholder for mask {}, // placeholder for lse {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout batch_stride_q, batch_stride_k, batch_stride_v, @@ -302,6 +342,15 @@ struct FmhaFwdKernel kargs.scale_p = scale_p; kargs.scale_o = scale_o; } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } return kargs; } @@ -312,6 +361,7 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + void* rand_val_ptr, void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, @@ -319,6 +369,7 @@ struct FmhaFwdKernel const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, @@ -327,16 +378,22 @@ struct FmhaFwdKernel ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -346,6 +403,7 @@ struct FmhaFwdKernel -1, // hdim_q, hdim_v, + num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), @@ -364,6 +422,7 @@ struct FmhaFwdKernel {}, // placeholder for mask {}, // placeholder for lse {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -389,12 +448,21 @@ struct FmhaFwdKernel { kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; } if constexpr(kDoFp8StaticQuant) { kargs.scale_p = scale_p; kargs.scale_o = scale_o; } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } return kargs; } @@ -426,12 +494,13 @@ struct FmhaFwdKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; if constexpr(kIsGroupMode) { @@ -455,7 +524,11 @@ struct 
FmhaFwdKernel } if constexpr(kStoreLSE) { - batch_offset_lse = query_start; + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; } batch_offset_o = query_start * kargs.stride_o; @@ -493,6 +566,11 @@ struct FmhaFwdKernel { batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } @@ -666,6 +744,62 @@ struct FmhaFwdKernel } }(); + // dropout + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + uint64_t drop_seed = 0; + uint64_t drop_offset = 0; + bool is_store_randval = false; + + if constexpr(kHasDropout) + { + rp_undrop = kargs.rp_undrop; + p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; + drop_seed = kargs.drop_seed; + drop_offset = kargs.drop_offset; + is_store_randval = kargs.is_store_randval; + } + BlockDropout dropout(i_batch, + i_nhead, + kargs.num_head_q, + drop_seed, + drop_offset, + rp_undrop, + p_undrop_in_uint8_t, + is_store_randval); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( @@ -723,6 +857,7 @@ struct FmhaFwdKernel identity{}, // v_element_func bias_dram_window, identity{}, // bias_element_func + randval_dram_window, lse_dram_window, identity{}, // lse_element_func identity{}, // s_acc_element_func @@ -731,7 +866,8 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, - smem_ptr); + smem_ptr, + dropout); } else { @@ -739,11 +875,13 @@ struct FmhaFwdKernel k_dram_window, v_dram_window, bias_dram_window, + randval_dram_window, lse_dram_window, mask, position_encoding, kargs.scale_s, - smem_ptr); + smem_ptr, + dropout); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp index e40b00668..2dca84b78 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp new file mode 100644 index 000000000..f18993703 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdOGradDotO +{ + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kVHeaddim = Problem::kVHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } + + template + CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + DDramBlockWindowTmp& d_dram_block_window_tmp, + float p_undrop) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kBlockSize == + OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kBlockSize == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + + auto o_dram_window = + make_tile_window(o_dram_block_window_tmp.get_bottom_tensor_view(), + o_dram_block_window_tmp.get_window_lengths(), + o_dram_block_window_tmp.get_window_origin(), + Policy::template MakePreODramTileDistribution()); + + auto o = load_tile(o_dram_window); + + auto do_dram_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + do_dram_block_window_tmp.get_window_origin(), + Policy::template MakePreOGradDramTileDistribution()); + + auto do_ = load_tile(do_dram_window); + + // declare d + constexpr auto d_dstr = + make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( + o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{})); + + auto d = make_static_distributed_tensor(d_dstr); + + clear_tile(d); // Initialize D + + constexpr auto o_spans = decltype(o)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + d(i_idx) += + (type_convert(o[i_j_idx]) * type_convert(do_[i_j_idx])); + }); + }); + + tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d); + + store_tile(d_dram_block_window_tmp, d); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp new file mode 100644 index 
000000000..7843ab33a --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// These templates are not used here. +using BlockFmhaBwdOGradDotODefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp new file mode 100644 index 000000000..344456750 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp @@ -0,0 +1,848 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineKSKTSVR +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = false; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = true; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = false; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide 
buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "ks_kts_vr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& qt_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& kt_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kVHeaddim == + OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in LDS + QDataType* q_lds_ptr = 
static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS + QDataType* qt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto qt_lds = make_tensor_view( + qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS + KDataType* kt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto kt_lds = make_tensor_view( + kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + OGradDataType* dot_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto dot_lds = make_tensor_view( + dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template 
MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // // persistent K in LDS + + auto kt_dram_block_window = kt_dram_block_window_tmp; + + auto kt_dram_window = make_tile_window( + kt_dram_block_window.get_bottom_tensor_view(), + kt_dram_block_window.get_window_lengths(), + kt_dram_block_window.get_window_origin(), + Policy::template MakeKTDramTileDistribution()); // K^T DRAM tile window for + // load + + auto kt_block_tile = load_tile(kt_dram_window); + + auto kt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledKTRegBlockDescriptor()); + shuffle_tile(kt_shuffle_tmp, kt_block_tile); + + store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto qt_dram_block_window = + make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), + qt_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dot_dram_block_window = + make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), + dot_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 
bias_origin.at(number<1>{})}); // M/N + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto qt_dram_window = + make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), + qt_dram_block_window.get_window_lengths(), + qt_dram_block_window.get_window_origin(), + Policy::template MakeQTDramTileDistribution()); + + auto dot_dram_window = + make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), + dot_dram_block_window.get_window_lengths(), + dot_dram_block_window.get_window_origin(), + Policy::template MakeOGradTDramTileDistribution()); + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template MakeBiasTTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + { + move_tile_window(q_dram_window, {0, kK0}); + + clear_tile(st_acc); // Initialize S^T + + store_tile(q_lds_window, q_block_tile); // LDS write 0 + q_block_tile = load_tile(q_dram_window); // global read 1 + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + 
get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + move_tile_window(q_dram_window, {0, kK0}); + + store_tile(q_lds_window, + q_block_tile); // LDS write i + 1 + q_block_tile = load_tile(q_dram_window); // global read i + 2 + }); + } + + const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{})); + block_sync_lds(); + + store_tile(q_lds_window, q_block_tile); + block_sync_lds(); + + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + block_sync_lds(); + auto bias_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + st_acc(i_j_idx) *= raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? 
type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + auto dot_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledOGradTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(dot_shuffle_tmp, dot_prefetch); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + } + move_tile_window(dot_dram_window, {0, kK1}); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto dot = load_tile(dot_dram_window); // load next OGrad^T + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile(pt_gemm, + sequence{}, + sequence<(i_k1 + 1) * kK1, kN0>{}), + dot_lds_window); + block_sync_lds(); + shuffle_tile(dot_shuffle_tmp, dot); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + + move_tile_window(dot_dram_window, {0, kK1}); + }); + } + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + // tail + { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence{}), + dot_lds_window); + block_sync_lds(); + } + + // STAGE 4, OGrad@V Gemm2 + auto dpt_acc = SPGradTBlockTileType{}; + + { + move_tile_window(do_dram_window, {0, kK2}); + + clear_tile(dpt_acc); // Initialize PGrad^T + + store_tile(do_lds_window, do_block_tile); // LDS write 0 + do_block_tile = load_tile(do_dram_window); // global read 1 + } + + if constexpr(k2_loops > 2) + { + static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + move_tile_window(do_dram_window, {0, kK2}); + + store_tile(do_lds_window, + do_block_tile); // LDS write i + 1 + do_block_tile = load_tile(do_dram_window); // global read i + 2 + }); + } + + const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile + { // tail + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 2) * kK2>{}, + sequence{})); + block_sync_lds(); + + store_tile(do_lds_window, do_block_tile); + block_sync_lds(); + + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 1) * kK2>{}, + sequence{})); + } + + // STAGE 5, P^T(PGrad^T - D) + const auto d = load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = 
decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + auto qt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledQTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(qt_shuffle_tmp, qt_prefetch); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + } + move_tile_window(qt_dram_window, {0, kK3}); + + const auto dst_gemm = cast_tile(dst); + + if constexpr(k3_loops > 1) + { + static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) { + const auto qt = load_tile(qt_dram_window); // load next Q^T + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile(dst_gemm, + sequence{}, + sequence<(i_k3 + 1) * kK3, kN0>{}), + qt_lds_window); + block_sync_lds(); + shuffle_tile(qt_shuffle_tmp, qt); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + + move_tile_window(qt_dram_window, {0, kK3}); + }); + } + // tail + { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence{}), + qt_lds_window); + block_sync_lds(); + } + + // STAGE 7, SGrad@K^T Gemm4 + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } 
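+    // Per-iteration math of the loop above, restated as a sketch (inferred from the
+    // STAGE comments; scaling factors applied after the loop are noted separately):
+    //   STAGE 1 (gemm_0): S^T   = Q @ K                     -> st_acc
+    //   STAGE 2         : P^T   = softmax(S^T)               (recomputed from the stored LSE)
+    //   STAGE 3 (gemm_1): dV   += P^T @ dO^T                 -> dv_acc
+    //   STAGE 4 (gemm_2): dP^T  = dO @ V                     -> dpt_acc
+    //   STAGE 5         : dS^T  = P^T * (dP^T - D)           -> dst
+    //   STAGE 6 (gemm_3): dK   += dS^T @ Q^T                 -> dk_acc
+    //   STAGE 7 (gemm_4): dQ    = dS @ K^T                   -> dq_acc, written back each
+    //                                                           iteration via update_tile
+    // dQ and dK pick up raw_scale (or scale_rp_undrop under dropout) and dV picks up
+    // rp_undrop under dropout; dk_acc/dv_acc are returned to the caller kernel, which
+    // stores them through the KGrad/VGrad epilogue pipelines.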
+}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp new file mode 100644 index 000000000..a05fbf252 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// This pipeline is v located in regs, k & k^t located in lds. +using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp new file mode 100644 index 000000000..dec421c1e --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp @@ -0,0 +1,821 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineKSVR +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = false; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = false; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = false; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static 
constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "ks_vr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& qt_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kVHeaddim == + OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + 
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in LDS + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS + QDataType* qt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto qt_lds = make_tensor_view( + qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS + auto kt_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptorAsKT()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + OGradDataType* dot_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto dot_lds = make_tensor_view( + dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using 
SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // // persistent K in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto qt_dram_block_window = + make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), + qt_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dot_dram_block_window = + make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), + dot_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto qt_dram_window = + make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), + qt_dram_block_window.get_window_lengths(), + qt_dram_block_window.get_window_origin(), + Policy::template MakeQTDramTileDistribution()); + + auto dot_dram_window = + make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), + dot_dram_block_window.get_window_lengths(), + 
dot_dram_block_window.get_window_origin(), + Policy::template MakeOGradTDramTileDistribution()); + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template MakeBiasTTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + { + move_tile_window(q_dram_window, {0, kK0}); + + clear_tile(st_acc); // Initialize S^T + + store_tile(q_lds_window, q_block_tile); // LDS write 0 + q_block_tile = load_tile(q_dram_window); // global read 1 + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + move_tile_window(q_dram_window, {0, kK0}); + + store_tile(q_lds_window, + q_block_tile); // LDS write i + 1 + q_block_tile = load_tile(q_dram_window); // global read i + 2 + }); + } + + const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{})); + block_sync_lds(); + + store_tile(q_lds_window, q_block_tile); + block_sync_lds(); + + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + 
sequence{})); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + block_sync_lds(); + auto bias_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + st_acc(i_j_idx) *= raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? 
type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + auto dot_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledOGradTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(dot_shuffle_tmp, dot_prefetch); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + } + move_tile_window(dot_dram_window, {0, kK1}); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto dot = load_tile(dot_dram_window); // load next OGrad^T + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile(pt_gemm, + sequence{}, + sequence<(i_k1 + 1) * kK1, kN0>{}), + dot_lds_window); + block_sync_lds(); + shuffle_tile(dot_shuffle_tmp, dot); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + + move_tile_window(dot_dram_window, {0, kK1}); + }); + } + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + // tail + { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence{}), + dot_lds_window); + block_sync_lds(); + } + + // STAGE 4, OGrad@V Gemm2 + auto dpt_acc = SPGradTBlockTileType{}; + + { + move_tile_window(do_dram_window, {0, kK2}); + + clear_tile(dpt_acc); // Initialize PGrad^T + + store_tile(do_lds_window, do_block_tile); // LDS write 0 + do_block_tile = load_tile(do_dram_window); // global read 1 + } + + if constexpr(k2_loops > 2) + { + static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + move_tile_window(do_dram_window, {0, kK2}); + + store_tile(do_lds_window, + do_block_tile); // LDS write i + 1 + do_block_tile = load_tile(do_dram_window); // global read i + 2 + }); + } + + const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile + { // tail + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 2) * kK2>{}, + sequence{})); + block_sync_lds(); + + store_tile(do_lds_window, do_block_tile); + block_sync_lds(); + + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 1) * kK2>{}, + sequence{})); + } + + // STAGE 5, P^T(PGrad^T - D) + const auto d = load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = 
decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + auto qt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledQTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(qt_shuffle_tmp, qt_prefetch); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + } + move_tile_window(qt_dram_window, {0, kK3}); + + const auto dst_gemm = cast_tile(dst); + + if constexpr(k3_loops > 1) + { + static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) { + const auto qt = load_tile(qt_dram_window); // load next Q^T + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile(dst_gemm, + sequence{}, + sequence<(i_k3 + 1) * kK3, kN0>{}), + qt_lds_window); + block_sync_lds(); + shuffle_tile(qt_shuffle_tmp, qt); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + + move_tile_window(qt_dram_window, {0, kK3}); + }); + } + // tail + { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence{}), + qt_lds_window); + block_sync_lds(); + } + + // STAGE 7, SGrad@K^T Gemm4 + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } 
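+
+    // Stage summary of the main loop above (one pass per kM0 rows of Q/dO):
+    //   1. S^T  = Q @ K                          (gemm_0)
+    //   2. P^T recomputed from S^T and LSE, plus bias/mask/dropout handling
+    //   3. dV  += P^T @ dO^T                      (gemm_1)
+    //   4. dP^T = dO @ V                          (gemm_2)
+    //   5. dS^T = P^T * (dP^T - D)
+    //   6. dK  += dS^T @ Q^T                      (gemm_3)
+    //   7. dQ   = dS @ K^T, written back with update_tile (gemm_4)
+    // dK/dV stay in registers for the whole loop and are returned to the caller
+    // after the final scaling above.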
+}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp new file mode 100644 index 000000000..cc4e6304d --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// This pipeline is v located in regs, k located in lds. +using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp new file mode 100644 index 000000000..97487bb71 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp @@ -0,0 +1,692 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = true; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = false; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = true; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; 
+ static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "qs_ks_vr_dos"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in 
LDS + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS + auto qt_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS + auto kt_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptorAsKT()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ())); + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + auto dot_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + // init VGrad & KGrad + 
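+        // dV and dK are accumulated across every kM0 iteration of the seqlen_q
+        // loop below, so they are zero-initialized once here and only scaled and
+        // returned after the loop; dQ is re-initialized and written out
+        // (update_tile) in every iteration instead.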
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // // persistent K in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template MakeBiasTTileDistribution()); + + auto 
randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + clear_tile(st_acc); // Initialize S^T + store_tile(q_lds_window, q_block_tile); // LDS write + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + get_slice_tile(q_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{}), + get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + }); + } + + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + get_slice_tile(q_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + block_sync_lds(); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + block_sync_lds(); + auto bias_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + st_acc(i_j_idx) *= 
raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + block_sync_lds(); + store_tile(do_lds_window, do_block_tile); // store the prefetch + + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence{}, sequence<(i_k1 + 1) * kK1, kN0>{}), + get_slice_tile(dot_lds_window, + sequence<0, i_k1 * kK1>{}, + sequence{})); + block_sync_lds(); + }); + + // STAGE 4, OGrad@V Gemm2 + auto dpt_acc = SPGradTBlockTileType{}; + clear_tile(dpt_acc); // Initialize PGrad^T + + static_for<0, k2_loops, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + get_slice_tile(do_lds_window, + sequence<0, i_k2 * kK2>{}, + sequence{}), + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + }); + + // STAGE 5, P^T(PGrad^T - D) + const auto d = load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? 
(dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + block_sync_lds(); + const auto dst_gemm = cast_tile(dst); + + static_for<0, k3_loops, 1>{}([&](auto i_k3) { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence{}, sequence<(i_k3 + 1) * kK3, kN0>{}), + get_slice_tile(qt_lds_window, + sequence<0, i_k3 * kK3>{}, + sequence{})); + block_sync_lds(); + }); + + // STAGE 7, SGrad@K^T Gemm4 + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp new file mode 100644 index 000000000..ac81990e0 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// This pipeline is v located in regs, q & k & do located in lds. 
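+// The boolean template arguments (QLoadOnce, QTLoadOnce, KLoadOnce, KTLoadOnce,
+// VLoadOnce, OGradLoadOnce, OGradTLoadOnce) are expected to match the
+// kQLoadOnce/kKLoadOnce/... constants of BlockFmhaBwdDQDKDVPipelineQSKSVROGradS:
+// Q, K and OGrad are staged whole in LDS and their transposed views alias the
+// same buffers, while V is kept in registers.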
+using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp new file mode 100644 index 000000000..a013ee3d5 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -0,0 +1,1343 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdPipelineDefaultPolicy +{ + static constexpr bool QLoadOnce = + QLoadOnce_; // if q load whole block length (qkhdim) to LDS at once + static constexpr bool QTLoadOnce = + QTLoadOnce_; // if q^t load whole block length (qkhdim) to LDS at once + static constexpr bool KLoadOnce = + KLoadOnce_; // if k load whole block length (qkhdim) to LDS at once + static constexpr bool KTLoadOnce = + KTLoadOnce_; // if k^t load whole block length (qkhdim) to LDS at once + static constexpr bool VLoadOnce = + VLoadOnce_; // if v load whole block length (vhdim) to Vgprs at once + static constexpr bool OGradLoadOnce = + OGradLoadOnce_; // if do load whole block length (vhdim) to LDS at once + static constexpr bool OGradTLoadOnce = + OGradTLoadOnce_; // if do^t load whole block length (vhdim) to LDS at once + + // these are for global load + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + using QDataType = remove_cvref_t; + return 16 / sizeof(QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + if constexpr(VLoadOnce) + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + } + else + { + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using ODataType = remove_cvref_t; + return 16 / sizeof(ODataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad() + { + using OGradDataType = remove_cvref_t; + return 16 / sizeof(OGradDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto 
vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentKGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentVGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentK() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + + // TODO: not correct! 
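+        // e.g. with kM0 = kN0 = 128 and kBlockSize = 256 (hypothetical numbers),
+        // total_pixels = 128 * 128 / 256 = 64 > 32, so the transposed bias
+        // access would use 8 elements per thread here.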
+ if constexpr(total_pixels > 32) + return 8; + else + return 4; + } + + // these are for lds + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() + { + // TODO: this is for 3d layout + using QDataType = remove_cvref_t; + return 16 / sizeof(QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias() + { + // TODO: this is for 3d layout + using BiasDataType = remove_cvref_t; + return 16 / sizeof(BiasDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad() + { + // TODO: this is for 3d layout + using OGradDataType = remove_cvref_t; + return 16 / sizeof(OGradDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad() + { + // TODO: this is for 3d layout + using GemmDataType = remove_cvref_t; + return 16 / sizeof(GemmDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVInRegDramTileDistribution() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); + + return v_block_dstr; + } + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() + { + constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(MNPerBlock + 1) * KPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto x_lds_block_desc = transform_tensor_descriptor( + x_lds_block_desc_0, + make_tuple(make_pass_through_transform(MNPerBlock), + make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return x_lds_block_desc; + } + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptorAsXT() + { + constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(MNPerBlock + 1) * KPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto xt_lds_block_desc = transform_tensor_descriptor( + x_lds_block_desc_0, + make_tuple(make_pass_through_transform(MNPerBlock), + make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return xt_lds_block_desc; + } + + template + 
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() + { + static_assert(PixelsPerRow % KPack == 0); + constexpr index_t NPerRow = PixelsPerRow / KPack; + static_assert(MNPerBlock % NPerRow == 0); + static_assert(KPerBlock % KPack == 0); + + constexpr auto xt_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number<(MNPerBlock / NPerRow) * (PixelsPerRow + KPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto xt_lds_block_desc = transform_tensor_descriptor( + xt_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return xt_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackQ(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptorAsQT() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackQ(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackK(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptorAsKT() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackK(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPack = GetSmemKPackV(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + constexpr index_t kKPack = GetSmemKPackOGrad(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptorAsOGradT() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + 
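+        // When OGradLoadOnce is set (and dO^T is not loaded separately), this
+        // transposed view just reinterprets the dO tile already resident in LDS
+        // (covering the full kVHeaddim), so no separate dO^T buffer is
+        // allocated (see GetSmemSizeOGradT below).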
constexpr index_t kKPack = GetSmemKPackOGrad(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPack = GetSmemKPackSGrad(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsBlockDescriptor() + { + using QDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QDataType); + constexpr index_t kKPack = GetSmemKPackQ(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsBlockDescriptor() + { + using KDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType); + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor() + { + using QGradDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QGradDataType); + constexpr index_t kKPack = GetSmemKPackOGrad(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor() + { + using BiasDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType); + constexpr index_t kKPack = GetSmemKPackBias(); + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kMPerBlock % kKPack == 0); + + constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto biast_lds_block_desc = transform_tensor_descriptor( + biast_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return biast_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + constexpr index_t smem_size_q = sizeof(typename 
Problem::QDataType) * + MakeQLdsBlockDescriptor().get_element_space_size(); + return smem_size_q; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQT() + { + constexpr index_t smem_size_qt = [&]() { + if constexpr(QLoadOnce && !QTLoadOnce) + return 0; + else + return sizeof(typename Problem::QDataType) * + MakeQTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_qt; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() + { + constexpr index_t smem_size_k = sizeof(typename Problem::KDataType) * + MakeKLdsBlockDescriptor().get_element_space_size(); + return smem_size_k; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKT() + { + constexpr index_t smem_size_kt = [&]() { + if constexpr(KLoadOnce && !KTLoadOnce) + return 0; + else + return sizeof(typename Problem::KDataType) * + MakeKTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_kt; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + { + constexpr index_t smem_size_v = [&]() { + if constexpr(VLoadOnce) + return 0; + else + return sizeof(typename Problem::VDataType) * + MakeVLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_v; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGrad() + { + constexpr index_t smem_size_do = + sizeof(typename Problem::OGradDataType) * + MakeOGradLdsBlockDescriptor().get_element_space_size(); + return smem_size_do; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGradT() + { + constexpr index_t smem_size_dot = [&]() { + if constexpr(OGradLoadOnce && !OGradTLoadOnce) + return 0; + else + return sizeof(typename Problem::OGradDataType) * + MakeOGradTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_dot; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSGrad() + { + constexpr index_t smem_size_ds = + sizeof(typename Problem::GemmDataType) * + MakeSGradLdsBlockDescriptor().get_element_space_size(); + return smem_size_ds; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeBias() + { + constexpr index_t smem_size_bias = [&]() { + if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return sizeof(typename Problem::BiasDataType) * + MakeBiasTLdsBlockDescriptor().get_element_space_size(); + else + return 0; + }(); + return smem_size_bias; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + constexpr index_t smem_size_q = GetSmemSizeQ(); + constexpr index_t smem_size_qt = GetSmemSizeQT(); + constexpr index_t smem_size_k = GetSmemSizeK(); + constexpr index_t smem_size_kt = GetSmemSizeKT(); + constexpr index_t smem_size_v = GetSmemSizeV(); + constexpr index_t smem_size_do = GetSmemSizeOGrad(); + constexpr index_t smem_size_dot = GetSmemSizeOGradT(); + constexpr index_t smem_size_ds = GetSmemSizeSGrad(); + constexpr index_t smem_size_bias = GetSmemSizeBias(); + constexpr index_t smem_size_transpose = max(smem_size_ds, smem_size_bias); + + index_t smem_size = 0; + + if constexpr(QLoadOnce && OGradLoadOnce) + smem_size += smem_size_q + smem_size_qt + smem_size_do + smem_size_dot + + smem_size_transpose; // 1~4 & 10 + else if(QLoadOnce && !OGradLoadOnce && !OGradTLoadOnce) + smem_size += smem_size_q + smem_size_qt + + max(smem_size_do, + smem_size_dot, + smem_size_transpose); // 5/7/11 TODO: 
Multiple buffers strategy + else if(!QLoadOnce && !QTLoadOnce && OGradLoadOnce) + smem_size += smem_size_do + smem_size_dot + + max(smem_size_q, + smem_size_qt, + smem_size_transpose); // 6/8/12 TODO: Multiple buffers strategy + else if(!QLoadOnce && !QTLoadOnce && !OGradLoadOnce && !OGradTLoadOnce) + smem_size += max(smem_size_q, + smem_size_qt, + smem_size_do, + smem_size_dot, + smem_size_transpose); // 9/13 TODO: Multiple buffers strategy + + // 14/15 needs to be adjusted + if constexpr(KLoadOnce) + smem_size += (smem_size_k + smem_size_kt); // 1~13 + else + smem_size = + max(smem_size_k, smem_size_kt, smem_size); // 14/15 TODO: Multiple buffers strategy + + return max(smem_size, smem_size_v); // 15 + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + + constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane; + constexpr index_t N0 = NWarp; + + constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * 2; + constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane; + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / 2; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple>, + tuple, sequence<1, 0>>, + tuple, sequence<3, 1>>, + sequence<1, 1, 1>, + sequence<0, 2, 4>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using VDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t K1 = 16 / sizeof(VDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + + constexpr index_t K1 = GetAlignmentQ(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { 
+ if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + + constexpr index_t K1 = GetAlignmentK(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + + constexpr index_t K1 = GetAlignmentOGrad(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution() + { + constexpr index_t K1 = 16 / sizeof(DataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = 1; + constexpr index_t M1 = get_warp_size(); + constexpr index_t M0 = MPerBlock / M1; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1>>, + tuple, sequence<1>>, + sequence<1, 2, 2>, + sequence<2, 0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution() + { + using ODataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution() + { + using OGradDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_DEVICE static constexpr auto MakeQTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + + constexpr index_t N1 = GetTransposedAlignmentQ(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
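+        // The transposed-load distributions (this MakeQTDramTileDistribution and the KT /
+        // OGradT variants below) split the tile as N0*N1 along the head dimension and
+        // K0*K1*K2*K3 along the sequence dimension, with N1 as the transposed vector access
+        // length per lane. Worked numbers under hypothetical sizes (kBlockSize = 256,
+        // warp size 64, kNPerBlock = 128, kKPerBlock = 32, N1 = 8, kKPack = 8):
+        //   N0 = 128 / 8 = 16            total_pixels = 128 * 32 / 256 = 16 per thread
+        //   K3 = 16 / 8  = 2             K2 = kKPack / K3 = 4
+        //   K1 = 64 / (K2 * N0) = 1      K0 = 256 / 64 = 4
+        //   check: K0 * K1 * K2 * K3 = 4 * 1 * 4 * 2 = 32 == kKPerBlock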
+ constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackQ(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + + constexpr index_t N1 = GetTransposedAlignmentQ(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackQ(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + constexpr index_t N1 = GetTransposedAlignmentK(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + constexpr index_t N1 = GetTransposedAlignmentK(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
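+        // Each Make*TDramTileDistribution above has a matching MakeShuffled*TRegBlockDescriptor
+        // (like this one) built from the same N0/N1/K0..K3 factorization; only the order of the
+        // per-thread (Y) dimensions is swapped. The apparent intent is that a tile loaded with
+        // K contiguous per lane can be re-viewed ("shuffled") in registers into the transposed
+        // order before being written through the *T LDS descriptors, i.e. a register-level
+        // transpose; this reading is inferred from the swapped trailing sequences.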
+ constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeOGradTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + constexpr index_t N1 = GetTransposedAlignmentOGrad(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackOGrad(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + constexpr index_t N1 = GetTransposedAlignmentOGrad(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackOGrad(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeBiasTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t N1 = GetTransposedAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
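+        // The bias tile follows the same load/shuffle pattern: MakeBiasTileDistribution loads
+        // the M0 x N0 bias block with N1 contiguous per lane, MakeShuffledBiasTileDistribution
+        // re-orders it in registers, and MakeBiasTLdsBlockDescriptor (above) provides the
+        // transposed LDS layout. Presumably this bias^T staging exists for the dBias output
+        // when kHasBiasGrad is enabled; GetSmemSizeBias() is non-zero only for
+        // BlockAttentionBiasEnum::ELEMENTWISE_BIAS.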
+ constexpr index_t M3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackBias(); + static_assert(kKPack % M3 == 0); + constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave + constexpr index_t M1 = get_warp_size() / (M2 * N0); + constexpr index_t M0 = kBlockSize / get_warp_size(); + static_assert(kMPerBlock == M0 * M1 * M2 * M3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 1>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t N1 = GetTransposedAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t M3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackBias(); + static_assert(kKPack % M3 == 0); + constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave + constexpr index_t M1 = get_warp_size() / (M2 * N0); + constexpr index_t M0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 1>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<1, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution() + { + using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile()); + return c_block_tensor_type::get_tile_distribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; + + return BlockGemmASmemBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV1CustomPolicy; + return BlockGemmARegBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmASmemBRegCRegV1CustomPolicy; + + return BlockGemmASmemBRegCRegV1{}; + } + + // template + // CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + // { + // using BlockGemmProblem = + // BlockGemmPipelineProblem>; + // constexpr auto 
warp_gemm = []() { + // if constexpr(std::is_same_v && + // std::is_same_v && + // std::is_same_v) + // { + // return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + // } + // else if constexpr(std::is_same_v && + // std::is_same_v && + // std::is_same_v) + // { + // return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + // } + // }(); + + // using BlockGemmPolicy = + // BlockGemmASmemBSmemCRegV1CustomPolicy; + + // return BlockGemmASmemBSmemCRegV1{}; + // } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV1CustomPolicy; + return BlockGemmARegBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; + return BlockGemmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp new file mode 100644 index 000000000..a54a9fcb3 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockFmhaBwdPipelineEnum +{ + KSKTSVR = 0, + QSKSVROGradS, + KSVR, +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp new file mode 100644 index 000000000..7b787e9f3 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
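+// The BlockFmhaBwdPipelineEnum values above are only used for codegen pattern matching.
+// Reading them by analogy with the forward pipeline names in this patch (e.g.
+// BlockFmhaPipelineQRKSVS in block_fmha_pipeline_qr_ks_vs.hpp keeps Q in registers and stages
+// K/V through shared memory), KSKTSVR presumably means K and K^T staged in shared memory with
+// V in registers, QSKSVROGradS presumably stages Q, K and dO through shared memory with V in
+// registers, and KSVR keeps only K in shared memory. This reading is an inference, not stated
+// by the patch itself.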
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdPipelineProblem +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; + static constexpr bool kHasDropout = Traits::kHasDropout; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + +template +struct BlockFmhaBwdOGradDotOPipelineProblem +{ + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using Traits = remove_cvref_t; + + static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0, + "kBlockSize should be divisible by get_warp_size()"); + + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kVHeaddim = kVHeaddim_; + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 159fb4074..1b72b6005 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -13,6 +13,7 @@ template struct BlockFmhaPipelineProblem { - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; - using FmhaMask = remove_cvref_t; - using Traits = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using 
FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; @@ -47,6 +49,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 60650761d..06ce3a651 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -14,19 +15,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQRKSVS { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -49,6 +51,7 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. 
tensor dist should able to overwrite this @@ -106,6 +109,7 @@ struct BlockFmhaPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, @@ -125,6 +129,7 @@ struct BlockFmhaPipelineQRKSVS const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, const SAccElementFunction& s_acc_element_func, @@ -133,7 +138,8 @@ struct BlockFmhaPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { static_assert( std::is_same_v> && @@ -240,6 +246,9 @@ struct BlockFmhaPipelineQRKSVS {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), @@ -475,6 +484,12 @@ struct BlockFmhaPipelineQRKSVS }); }); + if constexpr(kHasDropout) + { + dropout.Run( + smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); + } + block_sync_lds(); if constexpr(std::is_same_v) { @@ -589,6 +604,7 @@ struct BlockFmhaPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto @@ -596,11 +612,13 @@ struct BlockFmhaPipelineQRKSVS const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -610,6 +628,7 @@ struct BlockFmhaPipelineQRKSVS identity{}, bias_dram_block_window_tmp, identity{}, + randval_dram_block_window_tmp, lse_dram_block_window_tmp, identity{}, identity{}, @@ -618,7 +637,8 @@ struct BlockFmhaPipelineQRKSVS mask, position_encoding, scale_s, - smem_ptr); + smem_ptr, + dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 8a19deb02..9939a474b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
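+// Dropout support in BlockFmhaPipelineQRKSVS (changes above): the pipeline now takes a
+// RandValDramBlockWindowTmp plus a BlockDropout& argument, builds a randval window via
+// dropout.MakeRandvalDramWindow(randval_dram_block_window_tmp, seqlen_k_start), and, when
+// kHasDropout is set, calls
+//   dropout.Run(smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
+// once per K-tile, so the softmax tile p_compute is dropped (and the random values can be
+// written out) before it is fed into the P*V GEMM. Roughly, per iteration:
+//   s_acc = Q * K^T  ->  bias / mask / softmax  ->  p_compute
+//   if constexpr(kHasDropout) dropout.Run(...);
+//   p = cast(p_compute);  o_acc += p * V
+// The convenience operator() overload simply forwards identity functors plus the new randval
+// window and dropout object.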
#pragma once @@ -7,6 +7,7 @@ #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -15,19 +16,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQRKSVSAsync { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -54,6 +56,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. 
tensor dist should able to overwrite this @@ -118,6 +121,7 @@ struct BlockFmhaPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, @@ -137,6 +141,7 @@ struct BlockFmhaPipelineQRKSVSAsync const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, const SAccElementFunction& s_acc_element_func, @@ -145,7 +150,8 @@ struct BlockFmhaPipelineQRKSVSAsync FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { static_assert( std::is_same_v> && @@ -292,6 +298,9 @@ struct BlockFmhaPipelineQRKSVSAsync {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), @@ -558,6 +567,17 @@ struct BlockFmhaPipelineQRKSVSAsync }); }); + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); @@ -688,6 +708,7 @@ struct BlockFmhaPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto @@ -695,11 +716,13 @@ struct BlockFmhaPipelineQRKSVSAsync const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -709,6 +732,7 @@ struct BlockFmhaPipelineQRKSVSAsync identity{}, bias_dram_block_window_tmp, identity{}, + randval_dram_block_window_tmp, lse_dram_block_window_tmp, identity{}, identity{}, @@ -717,7 +741,8 @@ struct BlockFmhaPipelineQRKSVSAsync mask, position_encoding, scale_s, - smem_ptr); + smem_ptr, + dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 80f40f815..f4767de0e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
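+// In the async variant (above) the random values are staged through a small LDS buffer placed
+// right after the K/V tiles: dropout.Run() is given an LDS pointer offset from smem_ptr by
+// Policy::GetSmemSizeKV(). This matches the policy change later in this patch, where the old
+// GetSmemSize() is split into GetSmemSizeKV() plus GetSmemSizeDropout(), summed when K is
+// copied asynchronously and combined with max() otherwise.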
#pragma once @@ -14,19 +14,20 @@ namespace ck_tile { template struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -49,6 +50,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -106,20 +108,23 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& /*randval_dram_block_window_tmp*/, // not supported + LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported FmhaMask mask, PositionEncoding /*position_encoding*/, float scale_s, float descale_qk, float descale_sv, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& /*dropout*/) const // not supported { static_assert( std::is_same_v> && diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index e12e76706..bc9ca93d0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
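+// The [[deprecated]] fp8 pipeline above accepts the new randval window and BlockDropout
+// arguments but leaves them unnamed (/* not supported */), just as it already does for the LSE
+// window; presumably this keeps every pipeline's operator() signature uniform so the
+// dispatching kernel can call them interchangeably.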
#pragma once @@ -13,19 +13,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQSKSVS { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 4fda6f008..12af81bb9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -89,13 +89,13 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { - return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && @@ -212,13 +212,13 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { - return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && @@ -691,7 +691,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { // TODO: assume Q is in register // TODO: assume K/V has same data type @@ -702,6 +702,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + if constexpr(AsyncCopyK) + { + return GetSmemSizeKV() + GetSmemSizeDropout(); + } + else + { + return ck_tile::max(GetSmemSizeKV(), GetSmemSizeDropout()); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t 
GetSmemSizeDropout() + { + if constexpr(Problem::kHasDropout) + { + constexpr auto gemm_0 = QXPolicy::template GetQKBlockGemm(); + constexpr auto config = + decltype(gemm_0)::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = WG::kN; + + return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t); + } + else + { + return 0; + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index d8a290b09..64a61e94d 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -43,4 +43,53 @@ struct TileFmhaShape ck_tile::tensor_layout::gemm::ColumnMajor>; }; +template +struct TileFmhaBwdShape +{ + using BlockTile = remove_cvref_t; + using Gemm0BlockWarps = remove_cvref_t; + using Gemm0WarpTile = remove_cvref_t; + using Gemm1BlockWarps = remove_cvref_t; + using Gemm1WarpTile = remove_cvref_t; + using Gemm2BlockWarps = remove_cvref_t; + using Gemm2WarpTile = remove_cvref_t; + using Gemm3BlockWarps = remove_cvref_t; + using Gemm3WarpTile = remove_cvref_t; + using Gemm4BlockWarps = remove_cvref_t; + using Gemm4WarpTile = remove_cvref_t; + + static constexpr index_t NumWarps = + reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{}); + + static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) && + NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{})); + + static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen + static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen + static constexpr index_t kK0 = + BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll + static constexpr index_t kK1 = + BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll + static constexpr index_t kK2 = + BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll + static constexpr index_t kK3 = + BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll + static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll + static constexpr index_t kQKHeaddim = + BlockTile::at(number<7>{}); // Q & K headdim, used for pipeline that need load Q/Q^T or + // K/K^T at once + static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline + // that need load V at once +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 6cb6449f1..973ffa9f8 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
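+// Notes on the additions above:
+// * GetSmemSizeDropout() reserves (kMPerStep + 1) * kNPerStep bytes of uint8_t random values
+//   per block; the extra +1 is presumably the usual padding to avoid LDS bank conflicts, and
+//   the size is 0 when kHasDropout is false.
+// * TileFmhaBwdShape packs, in its BlockTile sequence, kM0/kN0 (q/k seqlen tiles), the five
+//   unroll sizes kK0..kK4 (one per backward GEMM: Q*K^T, P^T*dO, dO*V^T, dS^T*Q, dS*K) and
+//   kQKHeaddim/kVHeaddim for the load-once pipelines, e.g. (hypothetical numbers only)
+//   sequence<128, 128, 32, 32, 32, 32, 32, 64, 64>, alongside per-GEMM Gemm0..Gemm4
+//   BlockWarps/WarpTile sequences; NumWarps is derived from Gemm0BlockWarps and must match
+//   Gemm1/Gemm4.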
#pragma once @@ -13,7 +13,9 @@ template struct TileFmhaTraits @@ -23,9 +25,21 @@ struct TileFmhaTraits static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimV = kPadHeadDimV_; static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; +template +struct TileFmhaBwdOGradDotOTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index c7ebcf960..a89536e6e 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -3,20 +3,21 @@ #pragma once -#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp deleted file mode 100644 index 1053c751a..000000000 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp +++ /dev/null @@ -1,25 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
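+// The per-variant GEMM problem headers (block_gemm_areg_bgmem_creg_problem.hpp below, plus the
+// areg_bsmem and asmem_bsmem ones dropped from gemm.hpp above) are replaced by a single
+// block_gemm_problem.hpp. A sketch of the shared type, assuming it keeps the same members as
+// the per-variant structs it replaces:
+//   template <typename ADataType_, typename BDataType_, typename CDataType_,
+//             index_t kBlockSize_, typename BlockGemmShape_>
+//   struct BlockGemmProblem
+//   {
+//       using ADataType      = remove_cvref_t<ADataType_>;
+//       using BDataType      = remove_cvref_t<BDataType_>;
+//       using CDataType      = remove_cvref_t<CDataType_>;
+//       using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
+//       static constexpr index_t kBlockSize = kBlockSize_;
+//   };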
- -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { -// Problem Description for BlockGemmARegBGmemCReg -template -struct BlockGemmARegBGmemCRegProblem -{ - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = kBlockSize_; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp index 7799bbe91..f097790ae 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -28,7 +28,7 @@ struct BlockGemmARegBGmemCRegV1 // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1< - BlockGemmARegBSmemCRegProblem, + BlockGemmProblem, BlockGemmARegBSmemCRegV1DefaultPolicy>; CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp index 4156398bd..0a17b0535 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp index aac9c4f55..84883d6ed 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" namespace ck_tile { @@ -35,13 +35,16 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); - constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; - constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; - constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + // constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + // constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + // constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; - static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, - "wrong!"); + // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + // KPerBlock == BlockGemmShape::kK, + // "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -181,23 +184,10 @@ struct BlockGemmARegBSmemCRegV1 }); } - // C = A * B - template - CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + CK_TILE_DEVICE constexpr auto MakeCBlockTile() const { - static_assert( - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; - constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; - constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; - - static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, - "wrong!"); + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -208,20 +198,7 @@ struct BlockGemmARegBSmemCRegV1 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; - - constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; - constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - - const index_t iNWarp = get_warp_id() % NWarp; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, @@ -231,108 +208,20 @@ struct BlockGemmARegBSmemCRegV1 sequence<1, 2>, sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return 
c_block_tensor; + } - // constrcut from A-block-tensor from A-Block-tensor-tmp - // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent - // distribution - auto a_block_tensor = - make_static_distributed_tensor(a_block_dstr); - - a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( - b_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, - make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); - -#if 0 // FIXME: using array will cause register spill - array, NIterPerWarp> b_warp_windows{ - {b_warp_window_tmp}}; - - for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_warp_windows; - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); -#endif - - // Construct C-Block-HostTensor - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - using AWarpDstr = typename WG::AWarpDstr; - using CWarpDstr = typename WG::CWarpDstr; - - using AWarpTensor = typename WG::AWarpTensor; - using CWarpTensor = typename WG::CWarpTensor; - - constexpr auto a_warp_y_lengths = - to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - - constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); return c_block_tensor; } }; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp 
b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp index 779113d96..f998c67c9 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index 807398926..9b10d435b 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp index 405d7f125..4a82702c1 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp index 8bcd04b7b..20dcf2c27 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp index c17385b8e..e90500c28 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp new file mode 100644 index 000000000..65ce1a9b8 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block distributed tensor +// C is block distributed tensor +template +struct BlockGemmASmemBRegCRegV1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockWindowTmp& a_block_window_tmp, + const BBlockTensorTmp& b_block_tensor_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + // constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + // constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}]; + // constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + // KPerBlock == BlockGemmShape::kK, + // "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode); + + // constrcut from B-block-tensor from B-Block-tensor-tmp + // FIXME: need method to check b_block_tensor and b_block_tensor_tmp have equivalent + // distribution + auto b_block_tensor = + make_static_distributed_tensor(b_block_dstr); + + b_block_tensor.get_thread_buffer() = b_block_tensor_tmp.get_thread_buffer(); + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t 
kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A Block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp, + const BBlockTensorTmp& b_block_tensor_tmp) const + 
{ + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_window_tmp, b_block_tensor_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp new file mode 100644 index 000000000..5a17578f6 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockGemmASmemBRegCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp new file mode 100644 index 000000000..cd16f09c3 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBRegCRegV1 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmASmemBRegCRegV1DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { +#if 0 + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + + constexpr index_t NumWarp = kBlockSize / get_warp_size(); + + // FIXME + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); +#endif + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp deleted file mode 100644 index ed772891a..000000000 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp +++ /dev/null @@ -1,26 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, 
Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { - -// Problem Description for BlockGemmASmemBSmemCRegV1 -template -struct BlockGemmASmemBSmemCRegProblem -{ - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = kBlockSize_; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index 40da16d82..ac4522170 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp index 319711088..2436457ec 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index fbb957727..f798d6e81 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp similarity index 88% rename from include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp rename to include/ck_tile/ops/gemm/block/block_gemm_problem.hpp index 7a0390a8a..d8f66c81c 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp @@ -7,13 +7,13 @@ namespace ck_tile { -// Problem Description for BlockGemmARegBSmemCReg +// Problem Description for BlockGemm template -struct BlockGemmARegBSmemCRegProblem +struct BlockGemmProblem { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index dfc63f04c..5b4419b79 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 = using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl>; +using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl< + WarpGemmAtrributeMfmaIterateK_SwizzleA>; + using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution>; @@ -38,7 +41,7 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>; -using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution = +using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; @@ -56,6 +59,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl>; +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl< + WarpGemmAtrributeMfmaIterateK_SwizzleA>; + using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution>; @@ -72,7 +78,7 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>; -using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 71c59bbd1..fd5b004d3 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -468,4 +468,92 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB } }; +template +struct WarpGemmAtrributeMfmaIterateK_SwizzleA +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = + ext_vector_t::vector_size * kKIter>; + using BVecType = + ext_vector_t::vector_size * kKIter>; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t SFactor = SFactor_; // group how many CM1 together + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + 
sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); + }); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + constexpr auto I0 = number<0>{}; + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + auto c_vec = Impl{}( + reinterpret_cast(a_vec).template get_as()[I0], + reinterpret_cast(b_vec).template get_as()[I0]); + + static_for<1, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); + }); + + return c_vec; + } +}; + } // namespace ck_tile -- GitLab From cb0645bedca3650dcaed83d599dda664c727ce38 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 4 Jun 2024 19:28:15 -0500 Subject: [PATCH 42/96] Add a scale op, related instances and examples (#1242) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add a scale op * Update the element op * Add instances * Add an example * Add a client example * Add a flag check * Revert flag check addition * Fix flag check * Update d strides in example * Update d strides in client example * Apply suggestions from code review Update copyright header Co-authored-by: Bartłomiej Kocot * Move the example * Move the client example * Update element op * Update example with the new element op * Add scalar layout * Update example * Update kernel for scalar Ds * Revert kernel changes * Update element op * Update example to use scales' pointers * Format * Update instances * Update client example * Move element op to unary elements * Update element op to work with values instead of pointers * Update instances to take element op as an argument * Update examples to use random scale values --------- Co-authored-by: Bartłomiej Kocot --- .../24_grouped_conv_activation/CMakeLists.txt | 4 + .../grouped_convnd_fwd_convscale/common.hpp | 316 ++++++++++++++++++ .../conv3d_fwd_convscale_fp8.cpp | 50 +++ example/62_convnd_activ/CMakeLists.txt | 1 + .../62_convnd_activ/convscale/CMakeLists.txt | 10 + .../convscale/convnd_fwd_convscale_common.hpp | 301 +++++++++++++++++ .../convnd_fwd_xdl_convscale_fp8.cpp | 88 +++++ .../run_convnd_fwd_convscale_example.inc | 104 ++++++ .../element/unary_element_wise_operation.hpp | 23 ++ ...ped_conv_fwd_xdl_outelementop_instance.hpp | 78 +++++ .../grouped_convolution_forward_convscale.hpp | 108 ++++++ .../CMakeLists.txt | 5 + ...scale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp | 62 ++++ 13 files changed, 1150 insertions(+) create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp create mode 100644 example/62_convnd_activ/convscale/CMakeLists.txt create mode 100644 example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp create mode 100644 example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp create mode 
100644 example/62_convnd_activ/convscale/run_convnd_fwd_convscale_example.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index d4d5c545c..9d9b86ad0 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -35,6 +35,10 @@ target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composa add_executable(client_grouped_convnd_fwd_bilinear_residual_fp16 grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp) target_link_libraries(client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) +# Fwd convscale +add_executable(client_conv3d_fwd_convscale_fp8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_fp8 PRIVATE composable_kernel::device_conv_operations) # Bwd data bilinear add_executable(client_grouped_convnd_bwd_data_bilinear_residual_fp16 grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp) diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp new file mode 100644 index 000000000..79af6f09b --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd_convscale( + std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + float scale_in; + float scale_wei; + float scale_out; + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + 
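    // Worked example for the 3-D case: this rotate plus the one below turn in_lengths from the
    // caller's NDHWGC order {N, D, H, W, G, C} into the GNCDHW order {G, N, C, D, H, W} used by
    // the device op's descriptors; in_strides gets the same permutation just below, so only the
    // dimension ordering changes, never the data in memory.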
std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(out_lengths, wei_lengths, ds_size); + std::size_t num_bytes = + in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + ConvScale, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + 
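    // Notes on the numbers printed above: flop is G * N * C * (product of output spatial dims)
    // * (2 * K * (product of filter spatial dims) + ds_size), and avg_time is in milliseconds, so
    //   flop      / 1.E9 / avg_time  ==  FLOP  / (1e12 * s)  ==  TFLOP/s
    //   num_bytes / 1.E6 / avg_time  ==  bytes / (1e9  * s)  ==  GB/s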
if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp new file mode 100644 index 000000000..15d063c2f --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? 
EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt index 5a35f9b60..d1cf5c4ec 100644 --- a/example/62_convnd_activ/CMakeLists.txt +++ b/example/62_convnd_activ/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(binary) +add_subdirectory(convscale) add_subdirectory(multi_AB) add_subdirectory(unary) diff --git a/example/62_convnd_activ/convscale/CMakeLists.txt b/example/62_convnd_activ/convscale/CMakeLists.txt new file mode 100644 index 000000000..21017a5c2 --- /dev/null +++ b/example/62_convnd_activ/convscale/CMakeLists.txt @@ -0,0 +1,10 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl_convscale) + add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8) + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp b/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp new file mode 100644 index 000000000..978221f8e --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if 
constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // random scale values + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + // initialize 
out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(e_g_n_k_wos_lengths, b_g_k_c_xs_lengths, ds_size); + std::size_t num_btype = conv_param.GetInputByte() + + conv_param.GetWeightByte() + sizeof(float) + + sizeof(float) + sizeof(float) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + out_host.ForEach([&](auto&, auto idx) { out_element_op(out_host(idx), c(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp new file mode 100644 index 000000000..a7d69ccff --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale/run_convnd_fwd_convscale_example.inc b/example/62_convnd_activ/convscale/run_convnd_fwd_convscale_example.inc new file mode 100644 index 000000000..797146060 --- /dev/null +++ b/example/62_convnd_activ/convscale/run_convnd_fwd_convscale_example.inc @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + // instantiate in and wei element ops, will + // instantiate out_element_op below for every iteration + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = + [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto ds_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using DsLayout = decltype(ds_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ck::Tuple<>{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ck::Tuple<>{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ck::Tuple<>{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index bddf9087f..3404ef193 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -961,6 +961,29 @@ struct Elu const float alpha_; }; +struct ConvScale +{ + __host__ __device__ ConvScale(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + e = type_convert(c * scale_in_ * scale_wei_ * scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + // support fastconvert of int8 to fp16 template diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp new file mode 100644 index 000000000..a9ef244c8 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F8 = ck::f8_t; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_xdl_outelementop_f8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#ifdef CK_ENABLE_FP8 + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 
1, S<1, 16, 1, 4>, 1, F8, F8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 
4>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp new file mode 100644 index 000000000..50ae1cd41 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +#ifdef CK_ENABLE_FP8 +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + F8, + F8>>>& instances); +#endif + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = + DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + op_ptrs); + } +#endif + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt new file mode 100644 index 000000000..16ddbcb04 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt @@ -0,0 +1,5 @@ +# ONLY XDL_KERNELS +set(GROUPED_CONV3D_FWD_CONVSCALE + xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp) + +add_instance_library(device_grouped_conv3d_fwd_convscale_instance ${GROUPED_CONV3D_FWD_CONVSCALE}) 
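Editor's note: the grouped_convolution_forward_convscale.hpp header added above exposes the new f8 ConvScale instances through the usual DeviceOperationInstanceFactory path. The sketch below is a minimal illustration of how a client could enumerate those instances; the template arguments follow the declaration in that header, while the main() wrapper, the printing, and the closing comments about argument/invoker setup are illustrative additions, not part of the patch. Actually launching a convolution still requires device buffers and the MakeArgumentPointer/IsSupportedArgument/MakeInvokerPointer sequence of the device-op interface.

// Minimal sketch: enumerate the convscale f8 conv3d forward instances registered
// by this patch (assumes the header added above is on the include path).
#include <iostream>
#include <memory>
#include <vector>

#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"

int main()
{
    using namespace ck::tensor_operation::device;
    namespace ctl = ck::tensor_layout::convolution;

    using F8          = ck::f8_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using ConvScale   = ck::tensor_operation::element_wise::ConvScale;

    // 3D grouped forward convolution, NDHWGC/GKZYXC/NDHWGK layouts, f8 in / f8 out,
    // ConvScale epilogue, f8 compute types - the same problem description the new
    // add_device_grouped_conv3d_fwd_xdl_convscale_* function is declared for.
    using DeviceOp = DeviceGroupedConvFwdMultipleABD<3,
                                                     ctl::NDHWGC,
                                                     ctl::GKZYXC,
                                                     ck::Tuple<>,
                                                     ctl::NDHWGK,
                                                     F8,
                                                     F8,
                                                     ck::Tuple<>,
                                                     F8,
                                                     PassThrough,
                                                     PassThrough,
                                                     ConvScale,
                                                     F8,
                                                     F8>;

    // GetInstances() collects every instance registered for this description; with this
    // patch that is the ConvFwdDefault, Filter1x1Pad0 and Filter1x1Stride1Pad0 lists.
    auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " convscale f8 instances\n";
    for(const auto& op : op_ptrs)
        std::cout << "  " << op->GetTypeString() << '\n';

    // To run one of them: build an argument with op->MakeArgumentPointer(...) from the
    // convolution lengths/strides/dilations/pads, check op->IsSupportedArgument(arg.get()),
    // then launch through op->MakeInvokerPointer()->Run(arg.get()).
    return 0;
}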
diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp new file mode 100644 index 000000000..cfc99f9dc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + F8, + F8>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvScale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -- GitLab From ba82beb9bf4cef5df9f88f8e1b974c5a2dd996d0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 07:36:39 -0700 Subject: [PATCH 43/96] Bump rocm-docs-core from 1.2.1 to 1.3.0 in /docs/sphinx (#1324) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 1.2.1 to 1.3.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v1.2.1...v1.3.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 6ab8e14dd..e33f703c6 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.2.1 +rocm-docs-core==1.3.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 868c0044b..39d4b82a1 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==1.2.1 +rocm-docs-core==1.3.0 # via -r requirements.in six==1.16.0 # via -- GitLab From ac58cc5d1d4fe23fe5d08e16049d89397b8250e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 5 Jun 2024 20:01:29 +0200 Subject: [PATCH 44/96] Integrate universal gemm with conv forward (#1320) * Integrate universal gemm with conv fwd * Fix conv fwd wmma test * Fix instances * Remove direct load check --- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 4 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 10 +- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 1118 +++++++++++++++++ .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 105 +- ...ice_grouped_conv_fwd_xdl_comp_instance.hpp | 137 ++ ...vice_grouped_conv_fwd_xdl_mem_instance.hpp | 160 +++ .../gpu/grouped_convolution_forward.hpp | 35 +- .../grouped_convolution_forward_comp_xdl.inc | 112 ++ ...uped_convolution_forward_mem_inter_xdl.inc | 112 ++ ...uped_convolution_forward_mem_intra_xdl.inc | 112 ++ .../gpu/grouped_convolution_forward_xdl.inc | 2 +- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 14 + ...l_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp | 64 + ...dl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp | 64 + ...dl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp | 64 + ...wd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp | 4 +- ...gc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp | 66 + ...gc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp | 66 + ...wgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp | 66 + ...wgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp | 66 + ...wgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp | 66 + ...wgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp | 66 + .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 13 + ...dhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp | 54 + ...ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp | 54 + ...ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp | 54 + ..._gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp | 55 + ..._gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp | 55 + ...c_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp | 55 + ...c_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp | 55 + ...c_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp | 55 + ...c_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp | 55 + test/grouped_convnd_fwd/CMakeLists.txt | 10 +- ...l_wmma.cpp => test_grouped_convnd_fwd.cpp} | 2 +- 34 files changed, 2990 insertions(+), 40 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc 
create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp rename test/grouped_convnd_fwd/{test_grouped_convnd_fwd_xdl_wmma.cpp => test_grouped_convnd_fwd.cpp} (98%) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index c1c159101..c704cf059 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -674,7 +674,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle clear_workspace(); }; - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( stream_config, run_flush_cache, kernel, @@ -690,7 +690,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } else { - ave_time = launch_and_time_kernel_with_preprocess( + ave_time += launch_and_time_kernel_with_preprocess( stream_config, clear_workspace, kernel, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index c532eec99..28ad91efd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -820,15 +820,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return false; } } - else if(ck::is_lds_direct_load_supported()) - { - if constexpr(!(is_same_v || is_same_v || - is_same_v || is_same_v)) - { - return false; - } - } - else + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp new file mode 100644 index 000000000..986c41c51 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -0,0 +1,1118 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. 
+ * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_fwd_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t batch_count) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t batch_count) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); + + const long_index_t a_batch_offset = 
__builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeDataType = AComputeDataType> +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + static constexpr bool isMultiD = DsDataType::Size() > 0; + static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD; + + // multi ABD not supported + static_assert(!isMultiABD, "Multi A, Mutli B and Multi D are not supported"); + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto + MakeAGridDescriptor_AK0_M_AK1(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + 
conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeBGridDescriptor_BK0_N_BK1(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // desc for problem definition + using EGridDesc_M_N = remove_cvref_t({}, {}))>; + +#define GridwiseGemmV3TemplateParams \ + tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ + tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CShuffleDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, \ + MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ + AComputeDataType, BComputeDataType + + // Use appropriate gridwise gemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3; + + static auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const index_t M = e_grid_desc_m_n.GetLength(I0); + const index_t N = e_grid_desc_m_n.GetLength(I1); + return GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n, GridwiseGemm::CalculateMBlock(M), GridwiseGemm::CalculateNBlock(N)); + } + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = remove_cvref_t( + {}, {}, 
{}, {}, {}, {}, {}, {}, {}, {}))>; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t({}, {}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_as, + const void* p_bs, + const std::array&, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>&, + const std::array, NumDTensor>&, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_a_grid_{}, + p_b_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_bk0_n_bk1_{ + MakeBGridDescriptor_BK0_N_BK1(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + + // p_as and p_bs are pointers + p_a_grid_ = static_cast(p_as); + p_b_grid_ = static_cast(p_bs); + + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + } + + void Print() const + { + std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers (tuple if multi AB, pointer if no) + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_M_N e_grid_desc_m_n_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation 
cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float ave_time = 0; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); + + gdy *= arg.num_group_; + + index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + typename GridwiseGemm::Argument gemm_arg{ + arg.p_a_grid_, arg.p_b_grid_, arg.p_e_grid_, GemmM, GemmN, GemmK, I0, I0, I0, I1}; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper rotating_mem( + gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + arg.num_group_); + } + else + { + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + arg.num_group_); + } + }; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + 
ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + // Tail 
number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(get_device_name() == "gfx908") + { + // FIXME: re-enable fp64 when SWDEV-335738 is fixed + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + + if(!ck::is_xdl_supported()) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && 
LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check Gridwise GEMM + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, I1 /*KBatch*/}; + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_as, + const void* p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const 
std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CDEBlockTransferScalarPerVector_NPerBlock << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 50e6f68e6..bda2ded95 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1123,7 +1123,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 } template - __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) { const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( @@ -1141,26 +1141,22 @@ struct GridwiseGemm_xdl_cshuffle_v3 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - template __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, CDataType* p_c_grid, void* p_shared, - const Problem& problem) + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) { - const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( - problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, 
problem.AK0); - const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( - problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -1508,12 +1504,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 template - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, - const Problem& problem) + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) { const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); @@ -1521,11 +1516,42 @@ struct GridwiseGemm_xdl_cshuffle_v3 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); + Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -1879,6 +1905,43 @@ struct GridwiseGemm_xdl_cshuffle_v3 }); } } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } }; } // namespace ck diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp new file mode 100644 index 000000000..7490ef223 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + 
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Compute friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. 
+ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp new file mode 100644 index 000000000..2388c4db0 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_grouped_conv_fwd_xdl_bf16_mem_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| 
Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, 
BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_f16_mem_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, 
BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 8602a82ff..54826503a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -17,6 +17,9 @@ #endif #ifdef CK_USE_XDL #include "grouped_convolution_forward_xdl.inc" +#include "grouped_convolution_forward_comp_xdl.inc" +#include "grouped_convolution_forward_mem_inter_xdl.inc" +#include "grouped_convolution_forward_mem_intra_xdl.inc" #endif #ifdef CK_USE_WMMA #include "grouped_convolution_forward_wmma.inc" @@ -182,7 +185,7 @@ struct DeviceOperationInstanceFactory && is_same_v) { - add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); } #endif } @@ -196,6 +199,11 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -204,6 +212,11 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances( + op_ptrs); + 
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -214,6 +227,11 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + op_ptrs); } #endif } @@ -266,6 +284,11 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + op_ptrs); } #endif @@ -315,6 +338,11 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -325,6 +353,11 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_INT8 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc new file mode 100644 index 000000000..c93d6f441 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc new file mode 100644 index 000000000..b3913443d --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc new file mode 100644 index 000000000..6874822e7 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index e627d428d..aaac9a2af 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -75,7 +75,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( #ifdef CK_ENABLE_BF16 // grouped conv2d forward, GNHWC/GKYXC/GNHWK -void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp new file mode 100644 index 000000000..9b1c7ef65 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp new file mode 100644 index 000000000..93e07e08f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index 08770b861..2afbfdc38 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp new file mode 100644 index 000000000..3ae3fb518 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp new file mode 100644 index 000000000..cb7e91293 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp new file mode 100644 index 000000000..d787f4b04 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp new file mode 100644 index 000000000..564428979 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp new file mode 100644 index 000000000..5b12dad5a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 579bea00d..e24dbcd2c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -8,6 +8,19 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp + + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp + + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp + + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp new file mode 100644 index 000000000..efc464060 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp new file mode 100644 index 000000000..3f3cd4b7d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp new file mode 100644 index 000000000..386c62261 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp new file mode 100644 index 000000000..6e7b4624b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp new file mode 100644 index 000000000..6fab8347d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp new file mode 100644 index 000000000..f52f21454 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp new file mode 100644 index 000000000..e5311888e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp new file mode 100644 index 000000000..9524378e0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp new file mode 100644 index 000000000..49332076f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 4f245d63c..1eba91382 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -1,6 +1,10 @@ -add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd_xdl_wmma.cpp) -if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) +if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") + add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) + if(GPU_TARGETS MATCHES "gfx11") + target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) + else() + target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) + endif() endif() add_gtest_executable(test_grouped_convnd_fwd_multi_ab_interface test_grouped_convnd_fwd_multi_ab_interface.cpp) diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp similarity index 98% rename from test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp rename to test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index dde8313f9..125e4dc48 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include -- GitLab From 8f5690c4bb0116ae55969546f5190dde8144ddf0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 6 Jun 2024 22:38:26 -0700 Subject: [PATCH 45/96] Bump rocm-docs-core from 1.3.0 to 1.4.0 in /docs/sphinx (#1327) Bumps [rocm-docs-core](https://github.com/ROCm/rocm-docs-core) from 1.3.0 to 1.4.0. 
- [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.3.0...v1.4.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index e33f703c6..9c9706c66 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.3.0 +rocm-docs-core==1.4.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 39d4b82a1..313c30026 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==1.3.0 +rocm-docs-core==1.4.0 # via -r requirements.in six==1.16.0 # via -- GitLab From ce66277a762e29900affbd4844d73930b78a59b5 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Mon, 10 Jun 2024 14:48:49 -0500 Subject: [PATCH 46/96] Add a convinvscale op, related instances and examples (#1307) * Update the element op * Add an example * Add instances * Add a client example * make sure new instances only build on gfx9 * Update element op and its handling * Format * Update instances to take element op as an argument * Update examples to use random scale values * Format * Update client example with random scales * Format --------- Co-authored-by: illsilin --- .../24_grouped_conv_activation/CMakeLists.txt | 4 + .../common.hpp | 316 ++++++++++++++++++ .../conv3d_fwd_convinvscale_fp8.cpp | 50 +++ example/62_convnd_activ/CMakeLists.txt | 1 + .../convinvscale/CMakeLists.txt | 10 + .../convnd_fwd_convinvscale_common.hpp | 301 +++++++++++++++++ .../convnd_fwd_xdl_convinvscale_fp8.cpp | 88 +++++ .../run_convnd_fwd_convinvscale_example.inc | 104 ++++++ .../gpu/element/element_wise_operation.hpp | 20 -- .../element/unary_element_wise_operation.hpp | 23 ++ ...ouped_convolution_forward_convinvscale.hpp | 108 ++++++ .../CMakeLists.txt | 5 + ...scale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp | 62 ++++ 13 files changed, 1072 insertions(+), 20 deletions(-) create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/common.hpp create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp create mode 100644 example/62_convnd_activ/convinvscale/CMakeLists.txt create mode 100644 example/62_convnd_activ/convinvscale/convnd_fwd_convinvscale_common.hpp create mode 100644 example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp create mode 100644 example/62_convnd_activ/convinvscale/run_convnd_fwd_convinvscale_example.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp diff --git 
a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index 9d9b86ad0..29a2f3577 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -35,6 +35,10 @@ target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composa add_executable(client_grouped_convnd_fwd_bilinear_residual_fp16 grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp) target_link_libraries(client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) +# Fwd convinvscale +add_executable(client_conv3d_fwd_convinvscale_fp8 + grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convinvscale_fp8 PRIVATE composable_kernel::device_conv_operations) # Fwd convscale add_executable(client_conv3d_fwd_convscale_fp8 grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp) diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/common.hpp new file mode 100644 index 000000000..7059e24d8 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/common.hpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // 
sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd_convinvscale( + std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(out_lengths, wei_lengths, ds_size); + std::size_t num_bytes = + in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + ConvInvscale, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " 
instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvInvscale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvInvscale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp new file mode 100644 index 000000000..775ea99ec --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convinvscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt index d1cf5c4ec..fa5606773 100644 --- a/example/62_convnd_activ/CMakeLists.txt +++ b/example/62_convnd_activ/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(binary) +add_subdirectory(convinvscale) add_subdirectory(convscale) add_subdirectory(multi_AB) add_subdirectory(unary) diff --git a/example/62_convnd_activ/convinvscale/CMakeLists.txt b/example/62_convnd_activ/convinvscale/CMakeLists.txt new file mode 100644 index 000000000..07f42075b --- /dev/null +++ b/example/62_convnd_activ/convinvscale/CMakeLists.txt @@ -0,0 +1,10 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl_convinvscale) + add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convinvscale example_convnd_fwd_xdl_convinvscale_fp8) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/example/62_convnd_activ/convinvscale/convnd_fwd_convinvscale_common.hpp b/example/62_convnd_activ/convinvscale/convnd_fwd_convinvscale_common.hpp new file mode 100644 index 000000000..4b2ebf848 --- /dev/null +++ b/example/62_convnd_activ/convinvscale/convnd_fwd_convinvscale_common.hpp @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor 
out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // random scale values + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(e_g_n_k_wos_lengths, b_g_k_c_xs_lengths, ds_size); + std::size_t num_btype = conv_param.GetInputByte() + + conv_param.GetWeightByte() + sizeof(float) + + sizeof(float) + sizeof(float) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + out_host.ForEach([&](auto&, auto idx) { out_element_op(out_host(idx), c(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); + } + + return true; +} diff --git a/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp b/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp new file mode 100644 index 000000000..fbdfc7206 --- /dev/null +++ b/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
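The host-side verification in the common header above reuses the same epilogue the device kernels apply: ConvInvscale divides every convolution accumulator by the input, weight and output scale factors before narrowing to fp8. A stripped-down sketch of that element-wise op, with the fp8 conversion left out so it stays self-contained (the library version converts the result to ck::f8_t):

    struct ConvInvscaleSketch
    {
        float scale_in  = 1.f;
        float scale_wei = 1.f;
        float scale_out = 1.f;

        // e = c / scale_in / scale_wei / scale_out (then converted to fp8 in the real op)
        float operator()(float c) const { return c / scale_in / scale_wei / scale_out; }
    };

This mirrors the out_host.ForEach(...) step in the verification path, where the op is applied to each reference convolution result before comparing against the device output.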
+ +#include "convnd_fwd_convinvscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvInvscale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convinvscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convinvscale/run_convnd_fwd_convinvscale_example.inc b/example/62_convnd_activ/convinvscale/run_convnd_fwd_convinvscale_example.inc new file mode 100644 index 000000000..797146060 --- /dev/null +++ b/example/62_convnd_activ/convinvscale/run_convnd_fwd_convinvscale_example.inc @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + // instantiate in and wei element ops, will + // instantiate out_element_op below for every iteration + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = + [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto ds_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using DsLayout = decltype(ds_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ck::Tuple<>{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ck::Tuple<>{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ck::Tuple<>{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 5499689c9..798c5580a 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -528,26 +528,6 @@ struct UnaryTypeConvert } }; -struct ConvInvscale -{ - /// @brief Op to multiply convolution results by inverted scale factors - /// @param e Output after scaling - /// @param c Convolution result - /// @param d0 Input scale factor - /// @param d1 Weights scale factor - /// @param d2 Output scale factor - template - __host__ __device__ void - operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; - - template <> - __host__ __device__ void operator()( - f8_t& e, const float& c, const float& d0, const float& d1, const float& d2) const - { - e = type_convert(c / d0 / d1 / d2); - }; -}; - } // namespace element_wise } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 
3404ef193..28514fb78 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -961,6 +961,29 @@ struct Elu const float alpha_; }; +struct ConvInvscale +{ + __host__ __device__ ConvInvscale(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + e = type_convert(c / scale_in_ / scale_wei_ / scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + struct ConvScale { __host__ __device__ ConvScale(float scale_in = 1.f, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp new file mode 100644 index 000000000..e7c24f884 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; + +#ifdef CK_ENABLE_FP8 +void add_device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvInvscale, + F8, + F8>>>& instances); +#endif + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = + DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + op_ptrs); + } +#endif + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt new file mode 100644 index 000000000..bbbe18bea --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt @@ -0,0 +1,5 @@ +# ONLY XDL_KERNELS +set(GROUPED_CONV3D_FWD_CONVINVSCALE + xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp) + +add_instance_library(device_grouped_conv3d_fwd_convinvscale_instance ${GROUPED_CONV3D_FWD_CONVINVSCALE}) diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp new file mode 100644 index 000000000..bba72be7c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; + +void add_device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvInvscale, + F8, + F8>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvInvscale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvInvscale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvInvscale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -- GitLab From 5fc1bee4c547d6af39743542d39498453d024f68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 11 Jun 2024 09:52:38 +0200 Subject: [PATCH 47/96] Fix nhwgc f16 wmma instances (#1328) --- .../gpu/grouped_convolution_forward.hpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 54826503a..ec5bd785a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -402,6 +402,17 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(op_ptrs); + } +#endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && -- GitLab From acda4c5a3c34c13b71475fdd963e61182bba8a76 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Wed, 12 Jun 2024 14:41:56 -0500 Subject: [PATCH 48/96] Add instances for grouped conv fwd 3d with 
ConvScale for fp8@bf8->fp8 (#1325) * Add fp8 bf8 conv example * Add instances * Add client example * Add random scale values * Format --- .../24_grouped_conv_activation/CMakeLists.txt | 3 + .../grouped_convnd_fwd_convscale/common.hpp | 6 +- .../conv3d_fwd_convscale_fp8_bf8.cpp | 50 +++++++++++ .../62_convnd_activ/convscale/CMakeLists.txt | 2 + .../convnd_fwd_xdl_convscale_fp8_bf8.cpp | 88 +++++++++++++++++++ ...ped_conv_fwd_xdl_outelementop_instance.hpp | 38 ++++++++ .../grouped_convolution_forward_convscale.hpp | 27 ++++++ .../CMakeLists.txt | 3 +- ...e_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp | 62 +++++++++++++ 9 files changed, 275 insertions(+), 4 deletions(-) create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp create mode 100644 example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index 29a2f3577..b0c895d8a 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -43,6 +43,9 @@ target_link_libraries(client_conv3d_fwd_convinvscale_fp8 PRIVATE composable_kern add_executable(client_conv3d_fwd_convscale_fp8 grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp) target_link_libraries(client_conv3d_fwd_convscale_fp8 PRIVATE composable_kernel::device_conv_operations) +add_executable(client_conv3d_fwd_convscale_fp8_bf8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) # Bwd data bilinear add_executable(client_grouped_convnd_bwd_data_bilinear_residual_fp16 grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp) diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp index 79af6f09b..51eec5b1a 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp @@ -115,9 +115,9 @@ bool run_grouped_conv_fwd_convscale( SimpleDeviceMem wei(wei_mem_size); SimpleDeviceMem out(out_mem_size); - float scale_in; - float scale_wei; - float scale_out; + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); std::array in_strides; std::array wei_strides; diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp new file mode 100644 index 000000000..b38225f2b --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
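This patch pairs the ConvScale epilogue with mixed-precision inputs: fp8 activations and bf8 weights producing fp8 output. ConvScale carries the same three scale members as the ConvInvscale op added earlier in this series; it is understood here to multiply the accumulator by them rather than divide. A sketch under that assumption, again with the fp8 narrowing omitted:

    struct ConvScaleSketch
    {
        float scale_in  = 1.f;
        float scale_wei = 1.f;
        float scale_out = 1.f;

        // e = c * scale_in * scale_wei * scale_out (converted to fp8 by the library op)
        float operator()(float c) const { return c * scale_in * scale_wei * scale_out; }
    };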
+ +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::bf8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/62_convnd_activ/convscale/CMakeLists.txt b/example/62_convnd_activ/convscale/CMakeLists.txt index 21017a5c2..d6abb32f2 100644 --- a/example/62_convnd_activ/convscale/CMakeLists.txt +++ b/example/62_convnd_activ/convscale/CMakeLists.txt @@ -5,6 +5,8 @@ foreach(gpu IN LISTS GPU_TARGETS) add_custom_target(example_convnd_activ_xdl_convscale) add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp) add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8) + add_example_executable(example_convnd_fwd_xdl_convscale_fp8_bf8 convnd_fwd_xdl_convscale_fp8_bf8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8_bf8) set(target 1) endif() endforeach() diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp new file mode 100644 index 000000000..ab59e08a8 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
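+//
+// Example: DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instantiated with f8
+// activations (AComputeDataType = f8_t) and bf8 weights (BComputeDataType = bf8_t),
+// fp32 accumulation, and a ConvScale output elementwise op producing f8 output.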
+ +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::bf8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 
0 : 1; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp index a9ef244c8..6fbbaca7b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp @@ -15,6 +15,7 @@ namespace instance { using F32 = float; using F8 = ck::f8_t; +using BF8 = ck::bf8_t; template using S = ck::Sequence; @@ -72,6 +73,43 @@ using device_grouped_conv_fwd_xdl_outelementop_f8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_outelementop_f8_bf8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8) + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, 
GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 
1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp index 50ae1cd41..ed0198ef2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp @@ -39,6 +39,24 @@ void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instanc F8>>>& instances); #endif +#if defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8) +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + std::vector, + NDHWGK, + F8, + BF8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + F8, + BF8>>>& instances); +#endif + template && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + op_ptrs); + } #endif } return op_ptrs; diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt index 16ddbcb04..aef9c10c2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt @@ -1,5 +1,6 @@ # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_CONVSCALE - xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp) + xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp) add_instance_library(device_grouped_conv3d_fwd_convscale_instance ${GROUPED_CONV3D_FWD_CONVSCALE}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp new file mode 100644 index 000000000..d63f58853 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
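+//
+// Registers the f8 (activation) x bf8 (weight) -> f8 ConvScale grouped forward
+// convolution instances for 3D NDHWGC / GKZYXC / NDHWGK tensors, covering the
+// ConvFwdDefault, ConvFwd1x1P0 and ConvFwd1x1S1P0 specializations.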
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + std::vector, + NDHWGK, + F8, + BF8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + F8, + BF8>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvScale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -- GitLab From 37a347e3807198400d6ee1c8401f7c2cbb1d426e Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Thu, 13 Jun 2024 16:12:20 +0800 Subject: [PATCH 49/96] Fix to the using of static_for in amd_buffer_addressing.hpp (#1337) * Add insert_dummy_dep_per_dword over-loading for length 64 * Fix insert_dummy_dep_per_dword and remove over-loading for length 64 * Remove blank lines --------- Co-authored-by: Po Yen Chen --- include/ck_tile/core/arch/amd_buffer_addressing.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 9c6e85f01..13e92ef0b 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -552,8 +552,9 @@ namespace impl{ template CK_TILE_DEVICE void insert_dummy_dep_per_dword(array& b) { - static_for<0, b.size(), 1>{}([&](auto i){ - asm volatile(" " : : "v"(b.get(i)) : "memory"); + constexpr auto kSize = remove_cvref_t::size(); + static_for<0, kSize, 1>{}([&](auto i){ + asm volatile(" " : : "v"(b.get(number{})) : "memory"); }); } #if 1 -- GitLab From dc1e9c5df9e022b130337cc31fd8a32f6ce1efa7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 14 Jun 2024 16:53:03 +0200 Subject: [PATCH 50/96] Support large tensors in grouped conv fwd (#1332) * Support large tensors in grouped conv fwd * Multi ABD fixes * Fix calculate element space size --- .../impl/device_column_to_image_impl.hpp | 5 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 11 +- .../device_grouped_conv_bwd_weight_dl.hpp | 9 +- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 9 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 20 +-- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 9 +- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 18 +- ...ice_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp | 9 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 163 ++++++++++++------ ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 125 ++++++++------ ...fwd_multiple_d_multiple_r_xdl_cshuffle.hpp | 9 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 9 +- .../device/impl/device_grouped_conv_utils.hpp | 60 +++---- .../impl/device_image_to_column_impl.hpp | 5 +- 
...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 20 +-- .../transform_conv_fwd_to_gemm.hpp | 95 ++++++++-- .../test_grouped_convnd_fwd.cpp | 16 ++ 17 files changed, 368 insertions(+), 224 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp index 4c6546239..a7a366ffb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl independent_filter_strides, conv_filter_dilations, input_left_pads_with_offset, - input_right_pads); + input_right_pads, + N); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index c0fa9ad88..409e8c7b8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -93,12 +93,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index bd264a3c8..83db2485a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -54,12 +54,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - 
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 3c33c7dbc..380a06e0d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -66,12 +66,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index c704cf059..963f3f254 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -59,12 +59,9 @@ __global__ void const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -116,12 +113,9 @@ __global__ void const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const 
long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy @@ -1268,7 +1262,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle arg.Conv_G_; std::array in_out_batch_strides = { - arg.compute_ptr_offset_of_batch_.BatchStrideC_}; + static_cast(arg.compute_ptr_offset_of_batch_.BatchStrideC_)}; const auto kernel = kernel_batched_elementwise, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 96854e9a8..3babd1896 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -61,12 +61,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 7cfbd8a8f..3bb53920b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -97,12 +97,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -266,7 +263,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + a_g_n_c_wis_lengths[I1]); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -312,8 +310,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK const std::array& e_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index 6a4d97d7d..5c9d63e2b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -263,7 +263,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd& c_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(c_g_n_k_wos_lengths, - c_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 28ad91efd..88fe38add 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -69,7 +69,8 @@ template @@ -85,7 +86,7 @@ __global__ void const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, - const index_t batch_count, + const index_t groups_count, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -93,18 +94,22 @@ __global__ void const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + const ComputePtrOffsetOfG compute_ptr_offset_of_groups, + const ComputePtrOffsetOfN compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) + // offset base pointer for each work-group - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); + const index_t& num_blocks_per_n = groups_count; + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); + + const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx); + const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -121,13 +126,28 @@ __global__ void AsPointer p_as_grid_grp; BsPointer p_bs_grid_grp; - const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); + + // compute_ptr_offset_of_n_ not need BatchStrideB so + // in case of MultiA is false but isMultiB is true + // BatchStrideA_ is not tuple. 
+ if constexpr(isMultiA) + { + const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx); - static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); - static_for<0, NumATensor, 1>{}( - [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}([&](auto i) { + p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i]; + }); + } + else + { + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + static_for<0, 1, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; }); + } - const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); static_for<0, NumBTensor, 1>{}( @@ -137,7 +157,7 @@ __global__ void p_as_grid_grp, p_bs_grid_grp, p_ds_grid_grp, - p_e_grid + e_batch_offset, + p_e_grid + e_batch_offset + e_n_offset, p_shared, a_element_op, b_element_op, @@ -150,16 +170,16 @@ __global__ void } else { - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx); + + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); GridwiseGemm::template Run( - p_as_grid + a_batch_offset, + p_as_grid + a_batch_offset + a_n_offset, p_bs_grid + b_batch_offset, p_ds_grid_grp, - p_e_grid + e_batch_offset, + p_e_grid + e_batch_offset + e_n_offset, p_shared, a_element_op, b_element_op, @@ -175,7 +195,7 @@ __global__ void ignore = p_bs_grid; ignore = p_ds_grid; ignore = p_e_grid; - ignore = batch_count; + ignore = groups_count; ignore = a_grid_desc_k0_m_k1; ignore = b_grid_desc_k0_n_k1; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; @@ -183,7 +203,8 @@ __global__ void ignore = a_element_op; ignore = b_element_op; ignore = cde_element_op; - ignore = compute_ptr_offset_of_batch; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; ignore = block_2_ctile_map; #endif } @@ -309,7 +330,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t Conv_N) { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, @@ -321,7 +343,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + Conv_N); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -347,11 +370,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template static auto MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) + const std::array& e_g_n_k_wos_strides, + const index_t Conv_N) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template 
MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -363,24 +387,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Pass e_g_n_k_wos_lengths for logical broadcast. static auto MakeDsGridDescriptor_M_N( const std::array& e_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides) + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const index_t Conv_N) { return generate_tuple( [&](auto i) { using DLayout = remove_cvref_t>; - return DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - ds_g_n_k_wos_strides[i]); + return DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], Conv_N); }, Number{}); } // desc for problem definition using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>; using BGridDesc_N_K = remove_cvref_t({}, {}))>; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t({}, {}, 1))>; // If we are using multiAB and one of the template datatype parameters is not a tuple, convert // it to it @@ -468,6 +493,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, + conv_N_per_block_{ + conv_to_gemm_transformer.template GetSplitedNSize( + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, b_g_k_c_xs_lengths, @@ -477,12 +508,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads)}, + input_right_pads, + conv_N_per_block_)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)}, a_grid_desc_ak0_m_ak1_{ GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ @@ -490,7 +522,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, - compute_ptr_offset_of_batch_{}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -511,8 +544,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle if constexpr(isMultiA || isMultiB) { static_for<0, NumATensor, 1>{}([&](auto i) { - // Init compute_ptr_offset_of_batch_ for multiple AB - compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + // Init compute_ptr_offset_of_groups_ for multiple AB + compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data // type is not tuple) @@ -524,16 +557,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { // p_as is tuple p_as_grid_(i) = 
static_cast(p_as[i.value]); + // compute_ptr_offset_of_n_ not need BatchStrideB so + // in case of MultiA is false but isMultiB is true + // BatchStrideA_ is not tuple. + compute_ptr_offset_of_n_.BatchStrideA_(i) = + a_g_n_c_wis_strides[1] * conv_N_per_block_; } else { // if MultiB and not MultiA then p_as is single pointer p_as_grid_(i) = static_cast(p_as); + compute_ptr_offset_of_n_.BatchStrideA_ = + a_g_n_c_wis_strides[1] * conv_N_per_block_; } }); static_for<0, NumBTensor, 1>{}([&](auto i) { - // Init compute_ptr_offset_of_batch_ for multiple AB - compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + // Init compute_ptr_offset_of_groups_ for multiple AB + compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; using DataType = remove_cvref_t>; // It is possible that one of the AB is a pointer and one is a tuple. @@ -553,8 +593,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } else { - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; // p_as and p_bs are pointers p_as_grid_(I0) = static_cast(p_as); @@ -570,13 +611,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_(i) = static_cast(p_ds[i]); // D batch stride - compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides[i][1] * conv_N_per_block_; // D desc ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_); }); - compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; // populate desc for Ds/E if constexpr(isMultiA || isMultiB) @@ -638,6 +682,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // tensor descriptors for problem definiton index_t num_group_; + index_t conv_N_per_block_; + AGridDesc_M_K a_grid_desc_m_k_; BGridDesc_N_K b_grid_desc_n_k_; DsGridDesc_M_N ds_grid_desc_m_n_; @@ -655,7 +701,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // for computing batch offset ComputePtrOffsetOfStridedBatch - compute_ptr_offset_of_batch_; + compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; // element-wise op AElementwiseOperation a_element_op_; @@ -689,8 +736,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.Print(); } - const index_t grid_size = - arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N; + const index_t gdz = 1; const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); @@ -721,6 +772,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, Block2ETileMap, 
ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop, isMultiA, isMultiB>; @@ -728,7 +780,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return launch_and_time_kernel( stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_as_grid_, @@ -744,7 +796,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_etile_map_, - arg.compute_ptr_offset_of_batch_); + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); } else { @@ -763,6 +816,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, Block2ETileMap, ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop, isMultiA, isMultiB>; @@ -770,7 +824,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return launch_and_time_kernel( stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple @@ -786,7 +840,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_etile_map_, - arg.compute_ptr_offset_of_batch_); + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 986c41c51..ba9d967e9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -60,7 +60,7 @@ template (compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx); + + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -97,9 +99,9 @@ __global__ void CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, HasMainKBlockLoop, CGlobalMemoryDataOperation, - TailNum>(karg.p_a_grid + a_batch_offset, + TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset, karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, p_shared, karg, a_grid_desc_ak0_m_ak1, @@ -114,7 +116,7 @@ template (compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_n_offset = 
compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx); // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy @@ -154,9 +159,9 @@ __global__ void CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, HasMainKBlockLoop, CGlobalMemoryDataOperation, - TailNum>(karg.p_a_grid + a_batch_offset, + TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset, karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, p_shared_0, p_shared_1, karg, @@ -294,7 +299,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t Conv_N) + { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, @@ -306,7 +313,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + Conv_N); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -350,11 +358,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 template static auto MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) + const std::array& e_g_n_k_wos_strides, + const index_t Conv_N) + { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -363,7 +373,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } // desc for problem definition - using EGridDesc_M_N = remove_cvref_t({}, {}))>; + using EGridDesc_M_N = remove_cvref_t({}, {}, 1))>; #define GridwiseGemmV3TemplateParams \ tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ @@ -396,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // desc for blockwise copy using AGridDesc_AK0_M_AK1 = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>; using BGridDesc_BK0_N_BK1 = remove_cvref_t({}, {}))>; using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -429,6 +439,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 p_b_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, + conv_N_per_block_{ + conv_to_gemm_transformer.template GetSplitedNSize( + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, b_g_k_c_xs_lengths, @@ -438,13 +454,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads)}, + input_right_pads, + conv_N_per_block_)}, b_grid_desc_bk0_n_bk1_{ MakeBGridDescriptor_BK0_N_BK1(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, 
conv_N_per_block_)}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - compute_ptr_offset_of_batch_{}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -459,15 +477,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { - // A/B/E Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + // A/B/E Batch/N Stride + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; // p_as and p_bs are pointers p_a_grid_ = static_cast(p_as); p_b_grid_ = static_cast(p_bs); - compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; e_grid_desc_mblock_mperblock_nblock_nperblock_ = MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); @@ -488,6 +508,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // tensor descriptors for problem definiton index_t num_group_; + index_t conv_N_per_block_; // tensor descriptors for block/thread-wise copy AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; @@ -496,7 +517,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; // element-wise op AElementwiseOperation a_element_op_; @@ -538,11 +560,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + index_t gdx, gdy, gdz; std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); - gdy *= arg.num_group_; + gdy *= arg.num_group_ * num_workgroups_per_Conv_N; index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); @@ -579,7 +604,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_, arg.num_group_); } else @@ -594,7 +620,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_, arg.num_group_); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index ab1c4fc08..114fcbfcf 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + a_g_n_c_wis_lengths[I1]); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle const std::array& e_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 24bd0f242..d5cc5dc75 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + a_g_n_c_wis_lengths[I1]); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const std::array& e_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp index 9ae10441f..c20e5d36f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -68,14 +68,14 @@ template struct ComputePtrOffsetOfStridedBatch 1 || NumBTensor > 1)>> + enable_if_t<(NumATensor > 1 || NumBTensor > 1)>> { ComputePtrOffsetOfStridedBatch() = default; - ComputePtrOffsetOfStridedBatch(Array& BatchStrideAs, - Array& BatchStrideBs, - Array& BatchStrideDs, - index_t BatchStrideE) + ComputePtrOffsetOfStridedBatch(Array& BatchStrideAs, + Array& BatchStrideBs, + Array& BatchStrideDs, + long_index_t BatchStrideE) : BatchStrideA_(BatchStrideAs), BatchStrideB_(BatchStrideBs), BatchStrideDs_(BatchStrideDs), @@ -87,7 +87,7 @@ struct ComputePtrOffsetOfStridedBatch as_offset; static_for<0, NumATensor, 1>{}( - [&](auto i) { as_offset(i) = g_idx * static_cast(BatchStrideA_[i]); }); + [&](auto i) { as_offset(i) = static_cast(g_idx) * BatchStrideA_[i]; }); return as_offset; } @@ -95,7 +95,7 @@ struct ComputePtrOffsetOfStridedBatch bs_offset; static_for<0, NumBTensor, 1>{}( - [&](auto i) { bs_offset(i) = g_idx * static_cast(BatchStrideB_[i]); }); + [&](auto i) { bs_offset(i) = static_cast(g_idx) * BatchStrideB_[i]; }); return bs_offset; } @@ -103,40 +103,40 @@ struct ComputePtrOffsetOfStridedBatch ds_offset; static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + [&](auto i) { ds_offset(i) = static_cast(g_idx) * BatchStrideDs_[i]; }); return ds_offset; } [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } // alias for kernels without multiple D [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } - Array BatchStrideA_; - Array BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; - index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D + Array BatchStrideA_; + Array BatchStrideB_; + Array BatchStrideDs_; + long_index_t BatchStrideE_; + long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D }; template struct ComputePtrOffsetOfStridedBatch> + enable_if_t<(NumATensor == 1 && NumBTensor == 1)>> { ComputePtrOffsetOfStridedBatch() = default; - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) + 
ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA, + long_index_t BatchStrideB, + Array BatchStrideDs, + long_index_t BatchStrideE) : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideDs_(BatchStrideDs), @@ -146,38 +146,38 @@ struct ComputePtrOffsetOfStridedBatch(BatchStrideA_); + return static_cast(g_idx) * BatchStrideA_; } __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideB_); + return static_cast(g_idx) * BatchStrideB_; } __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const { Array ds_offset; static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + [&](auto i) { ds_offset(i) = static_cast(g_idx) * BatchStrideDs_[i]; }); return ds_offset; } [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } // alias for kernels without multiple D [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } - ck::index_t BatchStrideA_; - ck::index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; - index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D + long_index_t BatchStrideA_; + long_index_t BatchStrideB_; + Array BatchStrideDs_; + long_index_t BatchStrideE_; + long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D }; template diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp index 52aeefa3a..9ebcb2b8c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + N); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 82d010a99..dc639e995 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
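The stride-type widening in the two ComputePtrOffsetOfStridedBatch specializations above matters once a single tensor crosses the 2 GB mark: a per-group stride no longer fits in a 32-bit index_t, so it must be stored as long_index_t and g_idx must be widened before the multiply. A minimal sketch of the failure mode and the fix, with illustrative values rather than anything from the library:

#include <cstdint>
#include <iostream>

using index_t      = int32_t;
using long_index_t = int64_t;

int main()
{
    // A batch stride that no longer fits in 32 bits (e.g. a > 2 GB tensor slice).
    long_index_t batch_stride = 3'000'000'000LL;
    index_t g_idx             = 5;

    // Old layout: storing the stride as 32-bit corrupts it before any offset math happens.
    index_t narrow_stride = static_cast<index_t>(batch_stride); // truncating conversion

    // New layout: keep the stride as long_index_t and widen g_idx first, so the
    // whole multiply happens in 64 bits.
    long_index_t offset = static_cast<long_index_t>(g_idx) * batch_stride;

    std::cout << "truncated stride: " << narrow_stride << "\n"
              << "64-bit offset:    " << offset << "\n";
}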
#pragma once @@ -60,12 +60,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -155,12 +152,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index e2f75142d..3097a3293 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -20,6 +20,71 @@ struct TransformConvFwdToGemm static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; + static long_index_t + calculate_element_space_size_impl(const std::array& lengths, + const std::array& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } + + return acc; + } + + template + static index_t GetSplitedNSize(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); + const long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const index_t N = a_g_n_c_wis_lengths[I1]; + + if(element_space_size > TwoGB) + { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. + for(index_t least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Not possible to support even after split N. + // Too large tensor. + return N; + } + } + else + { + // Split N is not needed. 
+ return N; + } + } + // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as // properties template & conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t N) { - const index_t N = a_g_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2]; const index_t Wi = a_g_n_c_wis_lengths[3]; @@ -151,9 +216,10 @@ struct TransformConvFwdToGemm const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t N) + { - const index_t N = a_g_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2]; const index_t Hi = a_g_n_c_wis_lengths[3]; @@ -276,13 +342,14 @@ struct TransformConvFwdToGemm const std::array& b_g_k_c_xs_lengths, const std::array& /* b_g_k_c_xs_strides */, const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, + const std::array& /* c_g_n_k_wos_strides*/, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t N) + { - const index_t N = a_g_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2]; const index_t Di = a_g_n_c_wis_lengths[3]; @@ -478,9 +545,9 @@ struct TransformConvFwdToGemm bool>::type = false> static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */) + const std::array& /* c_g_n_k_wos_strides */, + const index_t N) { - const index_t N = c_g_n_k_wos_lengths[1]; const index_t K = c_g_n_k_wos_lengths[2]; const index_t NHoWo = @@ -502,9 +569,9 @@ struct TransformConvFwdToGemm is_same_v, bool>::type = false> static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides) + const std::array& c_g_n_k_wos_strides, + const index_t N) { - const index_t N = c_g_n_k_wos_lengths[1]; const index_t K = c_g_n_k_wos_lengths[2]; const auto KStride = I1; @@ -525,9 +592,9 @@ struct TransformConvFwdToGemm typename std::enable_if, bool>::type = false> static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides) + const std::array& c_g_n_k_wos_strides, + const index_t N) { - const index_t N = c_g_n_k_wos_lengths[1]; const index_t K = c_g_n_k_wos_lengths[2]; const index_t KStride = c_g_n_k_wos_strides[2]; diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index 125e4dc48..21fe7992a 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -69,6 +69,8 @@ using KernelTypes3d = ::testing::Types std::tuple, std::tuple>; +using KernelTypes2dLargeCases = ::testing::Types>; + template class TestGroupedConvndFwd1d : public TestGroupedConvndFwd { @@ -84,9 +86,15 @@ class TestGroupedConvndFwd3d : public TestGroupedConvndFwd { }; +template +class TestGroupedConvndFwd2dLargeCases : public TestGroupedConvndFwd +{ +}; + TYPED_TEST_SUITE(TestGroupedConvndFwd1d, KernelTypes1d); TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d); TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); +TYPED_TEST_SUITE(TestGroupedConvndFwd2dLargeCases, KernelTypes2dLargeCases); TYPED_TEST(TestGroupedConvndFwd1d, Test1D) { @@ -131,3 
+139,11 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D) {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->template Run<3>(); } + +TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases) +{ + // Case larger than 2GB + this->conv_params.push_back( + {2, 1, 64, 4, 192, {2, 2}, {224, 224}, {224, 224}, {0, 0}, {0, 0}, {0, 0}}); + this->template Run<2>(); +} -- GitLab From e02103168a21280d51faf9c5c4981d804e170c44 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Sun, 16 Jun 2024 20:33:47 -0500 Subject: [PATCH 51/96] disabled lds direct load inline asm (#1331) --- include/ck/ck.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 55f562061..32eea551f 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -155,7 +155,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1 // LDS direct loads using inline assembly -#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1 +#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0 // set stochastic rounding as default for f8 conversions #define CK_USE_SR_F8_CONVERSION 1 -- GitLab From 17ed368f5882dc71f70511bef86ce0831fd12f4d Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 17 Jun 2024 17:16:46 +0800 Subject: [PATCH 52/96] [CK_TILE][FA] using pk f16_f32 (#1343) * [CK_TILE][FA] using pk f16_f32 * correct a error --- include/ck_tile/core/arch/arch.hpp | 11 +++-- include/ck_tile/core/config.hpp | 4 ++ .../ck_tile/core/tensor/tile_elementwise.hpp | 43 ++++++++++++++++++- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 10 ++++- 4 files changed, 60 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 888f0e728..4a69f67ae 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -61,10 +61,13 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } CK_TILE_DEVICE void block_sync_lds() { #if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM - asm volatile("\ - s_waitcnt lgkmcnt(0) \n \ - s_barrier \ - " ::); + // asm volatile("\ + // s_waitcnt lgkmcnt(0) \n \ + // s_barrier \ + // " ::); + + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); #else __syncthreads(); #endif diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 10045d8f7..344343d93 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -167,6 +167,10 @@ #define CK_TILE_USE_SUBDWORD_TILE_CAST 0 #endif +#ifndef CK_TILE_USE_PK_FP16_TILE_CAST +#define CK_TILE_USE_PK_FP16_TILE_CAST 0 +#endif + // TODO: better solve this inside compiler #ifndef CK_TILE_FMHA_FWD_FAST_EXP2 #define CK_TILE_FMHA_FWD_FAST_EXP2 0 diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 48762b722..5fecd19dc 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -110,7 +110,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor) namespace impl { // TODO: this is ugly template -CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) +CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors) { #if defined(__gfx94__) // This API is designed to use the _pk_ serious of function @@ -156,6 +156,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) #endif } +template +CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors) 
+{ +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) + // This API is designed to use the _pk_ serious of function + constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); + + constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size(); + static_assert(thread_buffer_size % 2 == 0); + constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2; + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); + + // TODO: this is rtz cvt, need be very careful + for(index_t i = 0; i < thread_buffer_size_pk; i++) + { + auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0], + in_dstr_tensors.get_thread_buffer()[2 * i + 1]); + + out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x; + out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y; + } + + return out_dstr_tensor; +#else + // fallback + return tile_elementwise_in(type_convert, + in_dstr_tensors); +#endif +} + #if CK_TILE_USE_SUBDWORD_TILE_CAST // this function assume either src or dst (or both) date type is under 1 dword // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) @@ -229,8 +260,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) float> && (SrcTensor::get_thread_buffer_size() % 4 == 0)) { - return impl::cast_tile_pk_fp8x4(src_tensor); + return impl::cast_tile_pk_fp8_fp32(src_tensor); } +#if CK_TILE_USE_PK_FP16_TILE_CAST + else if constexpr(std::is_same_v && + std::is_same_v && + (SrcTensor::get_thread_buffer_size() % 2 == 0)) + { + return impl::cast_tile_pk_fp16_fp32(src_tensor); + } +#endif #if CK_TILE_USE_SUBDWORD_TILE_CAST else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 9939a474b..21784fc2d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -578,8 +578,14 @@ struct BlockFmhaPipelineQRKSVSAsync randval_dram_window); } - const auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pk_fp16_fp32( + tile_elementwise_in(p_compute_element_func, p_compute)); + else + return cast_tile( + tile_elementwise_in(p_compute_element_func, p_compute)); + }(); // STAGE 3, KV gemm if constexpr(k1_loops > 1) -- GitLab From 933951ed48f4f255be36e065cf77bb97dcca3bd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 18 Jun 2024 10:26:49 +0200 Subject: [PATCH 53/96] Fix continous dim selection in contraction (#1336) * Fix continous dim selection in contraction * Fixes --- ..._contraction_multiple_abd_xdl_cshuffle.hpp | 42 +++++++--------- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 43 ++++++++--------- .../device/impl/device_contraction_utils.hpp | 48 +++++++++++++++---- test/contraction/test_contraction_xdl.cpp | 6 +++ 4 files changed, 80 insertions(+), 59 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index 33e03a85e..dae16612c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle // for sanity check of vector memory access for(index_t i = 0; i < NumATensor; ++i) { - as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1; - as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1; - as_max_read_elems_[i] = + tie(as_continous_dim_[i], as_max_read_elems_[i]) = CalculateMaxRead(a_ms_ks_lengths[i], a_ms_ks_strides[i]); } for(index_t i = 0; i < NumBTensor; ++i) { - bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1; - bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1; - bs_max_read_elems_[i] = + tie(bs_continous_dim_[i], bs_max_read_elems_[i]) = CalculateMaxRead(b_ns_ks_lengths[i], b_ns_ks_strides[i]); } for(index_t i = 0; i < NumDTensor; ++i) { - ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; - ds_max_read_elems_[i] = + tie(ds_continous_dim_[i], ds_max_read_elems_[i]) = CalculateMaxRead(d_ms_ns_lengths[i], d_ms_ns_strides[i]); } - e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1; - e_max_write_elems_ = CalculateMaxRead(e_ms_ns_length, e_ms_ns_stride); + tie(e_continous_dim_, e_max_write_elems_) = + CalculateMaxRead(e_ms_ns_length, e_ms_ns_stride); } // pointers @@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle BElementwiseOperation b_element_op_; CDEElementwiseOperation cde_element_op_; - // Describe whether the last part of a given dimension of A/B/D/E is consecutive - // in the memory or not. - std::array as_mz_consecutive_; - std::array as_kz_consecutive_; - std::array bs_nz_consecutive_; - std::array bs_kz_consecutive_; - std::array ds_nz_consecutive_; - bool e_nz_consecutive_; + // Describe whether the last part of a given dimension of A/B/D/E is continues dim. 
+ std::array as_continous_dim_; + std::array bs_continous_dim_; + std::array ds_continous_dim_; + index_t e_continous_dim_; std::array as_max_read_elems_; std::array bs_max_read_elems_; @@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_a_vector_size = arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0; const bool valid_a_access_dim_m = - ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i]; + ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0; const bool valid_a_access_dim_k = - ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i]; + ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1; const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; if(!((valid_a_vector_size && valid_a_access_dim) || ABlockTransferSrcScalarPerVector == 1)) @@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_b_vector_size = arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0; const bool valid_b_access_dim_n = - BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i]; + BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0; const bool valid_b_access_dim_k = - BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i]; + BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1; const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; if(!((valid_b_vector_size && valid_b_access_dim) || BBlockTransferSrcScalarPerVector == 1)) @@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_d_vector_size = arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector read of Ds is always on N dimension. - const bool valid_d_access_dim = arg.ds_nz_consecutive_[i]; + const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1; if(!((valid_d_vector_size && valid_d_access_dim) || CDEBlockTransferScalarPerVector_NPerBlock == 1)) { @@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_e_vector_size = arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector write of E is always on N dimension. 
- const bool valid_e_access_dim = arg.e_nz_consecutive_; + const bool valid_e_access_dim = arg.e_continous_dim_ == 1; if(!((valid_e_vector_size && valid_e_access_dim) || CDEBlockTransferScalarPerVector_NPerBlock == 1)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 9d5b74be6..f1bc6a226 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle } // for sanity check of vector memory access - a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1; - a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1; - a_max_read_elems_ = + tie(a_continous_dim_, a_max_read_elems_) = CalculateMaxRead(a_ms_ks_lengths, a_ms_ks_strides); - b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1; - b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1; - b_max_read_elems_ = + tie(b_continous_dim_, b_max_read_elems_) = CalculateMaxRead(b_ns_ks_lengths, b_ns_ks_strides); for(index_t i = 0; i < NumDTensor; ++i) { - ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; - ds_max_read_elems_[i] = + tie(ds_continous_dim_[i], ds_max_read_elems_[i]) = CalculateMaxRead(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); } - e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1; - e_max_write_elems_ = + tie(e_continous_dim_, e_max_write_elems_) = CalculateMaxRead(e_ms_ns_lengths, e_ms_ns_strides); } @@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle BElementwiseOperation b_element_op_; CDEElementwiseOperation cde_element_op_; - // Describe whether the last part of a given dimension of A/B/D/E is consecutive - // in the memory or not. - bool a_mz_consecutive_; - bool a_kz_consecutive_; - bool b_nz_consecutive_; - bool b_kz_consecutive_; - std::array ds_nz_consecutive_; - bool e_nz_consecutive_; + // Describe whether the last part of a given dimension of A/B/D/E is continues dim. 
+ index_t a_continous_dim_; + index_t b_continous_dim_; + std::array ds_continous_dim_; + index_t e_continous_dim_; index_t a_max_read_elems_; index_t b_max_read_elems_; @@ -624,8 +615,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const bool valid_a_vector_size = arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0; - const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_; - const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_; + const bool valid_a_access_dim_m = + ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0; + const bool valid_a_access_dim_k = + ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1; const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1; if(!(valid_a_vector_size && valid_a_access_dim)) @@ -635,8 +628,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const bool valid_b_vector_size = arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0; - const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_; - const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_; + const bool valid_b_access_dim_n = + BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0; + const bool valid_b_access_dim_k = + BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1; const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1; if(!(valid_b_vector_size && valid_b_access_dim)) @@ -650,7 +645,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector read of Ds is always on N dimension. const bool valid_d_access_dim = - arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1; + arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1; if(!(valid_d_vector_size && valid_d_access_dim)) { valid_ds_access = false; @@ -665,7 +660,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector write of E is always on N dimension. const bool valid_e_access_dim = - arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1; + arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1; if(!(valid_e_vector_size && valid_e_access_dim)) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp index 838305f18..1b0db73fd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector& lengths, const std::vector= begin_idx; --dim_idx) { if(strides[dim_idx] == consecutive_stride) @@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector& lengths, const std::vectortemplate Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); + + // special cases + this->template Run<2>({{1, 1}, {16, 8}, {8, 16}}); + this->template Run<2>({{8, 16}, {16, 8}, {1, 1}}); + this->template Run<2>({{8, 16}, {1, 1}, {8, 16}}); + this->template Run<2>({{1, 1}, {1, 1}, {1, 1}}); } -- GitLab From e2d139201b8041726fc5f4c25f6689c4a83a6d6e Mon Sep 17 00:00:00 2001 From: jakpiase Date: Tue, 18 Jun 2024 16:01:49 +0200 Subject: [PATCH 54/96] Switch to universal gemm in grouped gemm tile loop (#1335) * switch to universal gemm in grouped gemm tile loop * minor fixes * add reviewers comments --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- ...emm_multiply_bias_fastgelu_xdl_bf16_i8.cpp | 2 +- .../grouped_gemm_multiply_xdl_bf16_i8.cpp | 2 +- .../grouped_gemm_multiple_d_xdl_fp16.cpp | 2 +- .../blockwise_gemm_pipeline_xdlops_v1.hpp | 8 +- .../blockwise_gemm_pipeline_xdlops_v2.hpp | 8 +- .../blockwise_gemm_pipeline_xdlops_v3.hpp | 4 +- .../blockwise_gemm_pipeline_xdlops_v4.hpp | 4 +- ...gemm_multiple_d_xdl_cshuffle_tile_loop.hpp | 544 ++++++++++++------ .../gpu/grid/block_to_ctile_map.hpp | 45 ++ .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 112 +++- ...pp => grouped_gemm_tile_loop_multiply.hpp} | 173 +++++- .../gpu/grouped_gemm_tile_loop/CMakeLists.txt | 18 +- ...ile_loop_f16_f16_f16_mk_kn_mn_instance.cpp | 20 +- ...ile_loop_f16_f16_f16_mk_nk_mn_instance.cpp | 26 +- ...le_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp | 93 +++ ...i8_bf16_mk_kn_mn_comp_default_instance.cpp | 35 ++ ...8_bf16_mk_kn_mn_comp_kpadding_instance.cpp | 35 ++ ...bf16_mk_kn_mn_comp_mnkpadding_instance.cpp | 35 ++ ..._bf16_mk_kn_mn_comp_mnpadding_instance.cpp | 35 ++ ...ultiply_bf16_i8_bf16_mk_kn_mn_instance.cpp | 190 ++++-- ..._bf16_mk_kn_mn_mem_v1_default_instance.cpp | 36 ++ ...bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp | 36 ++ ...16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp | 36 ++ ...f16_mk_kn_mn_mem_v1_mnpadding_instance.cpp | 36 ++ ..._bf16_mk_kn_mn_mem_v2_default_instance.cpp | 36 ++ ...bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp | 36 ++ ...16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp | 36 ++ ...f16_mk_kn_mn_mem_v2_mnpadding_instance.cpp | 36 ++ ...ly_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp | 40 ++ ...astgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp | 41 ++ ...astgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp | 39 ++ ...e_grouped_gemm_multiply_tile_loop_impl.hpp | 347 +++++++++++ profiler/src/CMakeLists.txt | 1 + ...rofile_grouped_gemm_multiply_tile_loop.cpp | 133 +++++ 34 files changed, 1956 insertions(+), 324 deletions(-) rename library/include/ck/library/tensor_operation_instance/gpu/{grouped_gemm_tile_loop_multply.hpp => grouped_gemm_tile_loop_multiply.hpp} (55%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp create mode 100644 
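The "special cases" added just below exercise extents of length 1, where every stride trivially looks consecutive, so a CalculateMaxRead-style helper has to report both which logical dimension really is the contiguous one and the widest safe vector access. A simplified, hedged sketch of that idea follows; the names and the exact rules are mine, not the library's implementation:

#include <cstdint>
#include <utility>
#include <vector>

// Classify which logical dimension ends contiguously and how many elements can be
// read as one vector. `split` is the number of components belonging to logical
// dim 0 (e.g. the M components of A); the remaining components belong to dim 1.
// Length-1 components are skipped, which is exactly what makes extents like {1, 1}
// "special": they can never decide the contiguous dimension on their own.
std::pair<int, int64_t> calculate_max_read(const std::vector<int64_t>& lengths,
                                           const std::vector<int64_t>& strides,
                                           std::size_t split)
{
    int continuous_dim = -1; // -1: undecided (all components degenerate)
    int64_t max_elems  = 1;
    int64_t expected   = 1;  // stride a component must have to extend the contiguous run

    for(std::size_t i = lengths.size(); i-- > 0;)
    {
        if(lengths[i] == 1)
            continue;                 // degenerate component, proves nothing
        if(strides[i] != expected)
            break;                    // run of consecutive memory ends here
        const int dim = (i < split) ? 0 : 1;
        if(continuous_dim == -1)
            continuous_dim = dim;     // innermost real component decides the dimension
        else if(dim != continuous_dim)
            break;                    // do not merge reads across logical dimensions
        max_elems *= lengths[i];
        expected *= lengths[i];
    }
    return {continuous_dim == -1 ? 1 : continuous_dim, max_elems};
}

// Example: A with component lengths {8,16, 16,8} (M parts, then K parts) and packed
// strides {2048,128, 8,1}: calculate_max_read(..., 2) -> {1, 128}, i.e. the K side
// is contiguous and up to 128 elements can be read as one vector.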
library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp create mode 100644 profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp create mode 100644 profiler/src/profile_grouped_gemm_multiply_tile_loop.cpp diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp index 36637df46..4b284c74d 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp @@ -13,7 +13,7 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp" #include "ck/host_utility/hip_check_error.hpp" diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp 
b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp index f71b6a13f..6cc83e06f 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp @@ -13,7 +13,7 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp" #include "ck/host_utility/hip_check_error.hpp" diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index 2b891dd6f..965a0e7e3 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -63,7 +63,7 @@ using DeviceGemmInstance = //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>; // clang-format on struct ProblemSize final diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index 0a7ad545b..7dd86468b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -144,12 +144,12 @@ struct BlockwiseGemmXdlops_pipeline_v1 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { ignore = num_loop; return TailNumber::Full; @@ -446,12 +446,12 @@ struct BlockwiseGemmXdlops_pipeline_v1 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { ignore = num_loop; return TailNumber::Full; diff --git 
a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 9acfd0085..dad643ffa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -153,12 +153,12 @@ struct BlockwiseGemmXdlops_pipeline_v2 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { if(num_loop % PrefetchStages == 1) { @@ -646,12 +646,12 @@ struct BlockwiseGemmXdlops_pipeline_v2 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { if(num_loop % PrefetchStages == 1) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index 3acfe0daa..52f48d0e4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -146,12 +146,12 @@ struct BlockwiseGemmXdlops_pipeline_v3 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { ignore = num_loop; return TailNumber::Full; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp index 75569150b..51ce8ae61 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp @@ -147,12 +147,12 @@ struct BlockwiseGemmXdlops_pipeline_v4 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { if(num_loop % HotloopUnroll == 1) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 36cbd1cd2..70011124f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -19,6 +19,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" @@ -42,16 +43,22 @@ namespace device { template + typename CDEElementwiseOperation, + BlockGemmPipelineScheduler BlkGemmPipeSched, + BlockGemmPipelineVersion BlkGemmPipelineVer> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -67,6 +74,7 @@ __global__ void constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t 
p_shared1[shared_size]; const auto gemm_desc_ptr = reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); @@ -81,27 +89,8 @@ __global__ void index_t gemm_tile_id_start = 0; index_t gemm_tile_id_end = 0; - using AGridDescMK = - remove_cvref_t( - 1, 1, 1))>; - using BGridDescNK = - remove_cvref_t( - 1, 1, 1))>; - using EGridDescMN = - remove_cvref_t( - 1, 1, 1))>; - using DsGridDescMN = - remove_cvref_t( - {}, {}, {}))>; - index_t M = 0, N = 0, K = 0; - index_t StrideA, StrideB, StrideE; - std::array StrideDs; - AGridDescMK a_grid_desc_mk; - BGridDescNK b_grid_desc_nk; - EGridDescMN e_grid_desc_mn; - DsGridDescMN ds_grid_desc_mn; auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); do @@ -127,31 +116,13 @@ __global__ void } b2c_tile_map = - OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset); + OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset); grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); gemm_tile_id_start = group_offset; gemm_tile_id_end = group_offset + grid_size_grp; } - StrideA = gemm_desc_ptr[group_id].StrideA; - StrideB = gemm_desc_ptr[group_id].StrideB; - StrideDs = gemm_desc_ptr[group_id].StrideDs; - StrideE = gemm_desc_ptr[group_id].StrideE; - - a_grid_desc_mk = - GridwiseGemm::template MakeAGridDescriptor_M_K(M, K, StrideA); - b_grid_desc_nk = - GridwiseGemm::template MakeBGridDescriptor_N_K(K, N, StrideB); - e_grid_desc_mn = - GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); - - static_for<0, NumDTensor, 1>{}([&](auto j) { - using DLayout = remove_cvref_t>; - ds_grid_desc_mn(j) = GridwiseGemm::template MakeEGridDescriptor_M_N( - M, N, StrideDs[j]); - }); - using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); DsGridPointer p_ds_grid; @@ -160,42 +131,268 @@ __global__ void p_ds_grid(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); }); - bool has_main_kblock_loop = - GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{})); + static constexpr index_t kbatch = 1; + static constexpr index_t k_grain = kbatch * KPerBlock; + index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + // Update tile offset if we have moved within group b2c_tile_map.UpdateTileOffset(tile_offset); - if(has_main_kblock_loop) + using Problem = typename GridwiseGemm::Problem; + auto problem = Problem(gemm_desc_ptr[group_id].M, + gemm_desc_ptr[group_id].N, + gemm_desc_ptr[group_id].K, + gemm_desc_ptr[group_id].StrideA, + gemm_desc_ptr[group_id].StrideB, + gemm_desc_ptr[group_id].StrideDs, + gemm_desc_ptr[group_id].StrideE, + kbatch); + + if(has_main_k_block_loop) { - GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid, - gemm_desc_ptr[group_id].p_e_grid, - static_cast(p_shared), - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_mk, - b_grid_desc_nk, - ds_grid_desc_mn, - e_grid_desc_mn, - b2c_tile_map); + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else if constexpr(BlkGemmPipelineVer == 
BlockGemmPipelineVersion::v2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + } + // Tail number could be 
Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + GridwiseGemm::template Run_2Lds( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + static_cast(p_shared1), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else + { + GridwiseGemm::template Run_2Lds( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + static_cast(p_shared1), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } } else { - GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid, - gemm_desc_ptr[group_id].p_e_grid, - static_cast(p_shared), - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_mk, - b_grid_desc_nk, - ds_grid_desc_mn, - e_grid_desc_mn, - b2c_tile_map); + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } } tile_id += get_grid_size(); @@ -253,10 +450,12 @@ template + typename CDEShuffleBlockTransferScalarPerVectors, + BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, + typename ComputeTypeA = EDataType, + typename ComputeTypeB = ComputeTypeA> + struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop : public DeviceGroupedGemmTileLoop; - - template - struct OffsettedBlockToCTileMap - { - using underlying_type = UnderlyingBlockToCTileMap; - - __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, - index_t group_offset, - index_t tile_offset) - : block_to_ctile_map_{block_to_ctile_map}, - group_offset_{group_offset}, - tile_offset_{tile_offset} - { - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - return block_to_ctile_map_.CalculateBottomIndex( - make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_)); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const - { - return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); - } + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; - template - __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); - } - - __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const - { - return block_to_ctile_map_.CalculateGridSize(M, N); - } - - __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; } - UnderlyingBlockToCTileMap block_to_ctile_map_; - index_t group_offset_; - index_t tile_offset_; - }; - - using KernelArguments = GroupedGemmTileLoopKernelArguments; - using Block2ETileMap = BlockToCTileMap_N00_M0_N01Adapt; - using OffsetedLocalBlock2ETileMap = 
OffsettedBlockToCTileMap; + using KernelArguments = GroupedGemmTileLoopKernelArguments; + using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; // Argument struct Argument : public BaseArgument @@ -403,7 +561,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop const void* p_dev_gemm_args_; int occupancy_num_blocks_; int gpu_cu_count_; - const std::vector& gemm_descs_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; @@ -496,16 +653,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + CDEElementwiseOperation, + BlkGemmPipeSched, + BlkGemmPipelineVer>; return LaunchKernel(kernel, arg, dev_gemm_args, stream_config); } @@ -546,6 +709,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop << std::endl; } + // run multiple kernels + return launch_and_time_kernel(stream_config, kernel, dim3(grid_size), @@ -572,63 +737,41 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop return false; } - using DsGridDescMN = remove_cvref_t< - decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N( - {}, {}, {}))>; - bool supported = true; - for(const auto& gdesc : arg.gemm_descs_) + constexpr index_t k_batch = 1; + for(index_t i = 0; i < arg.group_count_; ++i) { - const auto M = gdesc.M_; - const auto N = gdesc.N_; - const auto K = gdesc.K_; - - const auto StrideA = gdesc.stride_A_; - const auto StrideB = gdesc.stride_B_; - const auto StrideE = gdesc.stride_C_; - const auto& StrideDs = gdesc.stride_Ds_; - - // If M dimension is unknown at launch time then validate just NK. - // If N or K dim is zero (or unknown) then the vector loads responsibility lies on - // the user. 
- if(N * K == 0) - continue; - - const auto a_grid_desc_mk = - GridwiseGemm::template MakeAGridDescriptor_M_K(M, K, StrideA); - const auto b_grid_desc_nk = - GridwiseGemm::template MakeBGridDescriptor_N_K(K, N, StrideB); - const auto e_grid_desc_mn = - GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); - - DsGridDescMN ds_grid_desc_mn; - static_for<0, NumDTensor, 1>{}([&](auto j) { - using DLayout = remove_cvref_t>; - ds_grid_desc_mn(j) = - GridwiseGemm::template MakeEGridDescriptor_M_N( - M, N, StrideDs[j]); - }); - - const auto b2c_tile_map = Block2ETileMap(M, N); - - if(!(GridwiseGemm::template CheckValidity(a_grid_desc_mk, - b_grid_desc_nk, - ds_grid_desc_mn, - e_grid_desc_mn, - b2c_tile_map) && - GridwiseGemm::template CheckTensorTransfersValidity( - M, N, K))) + std::array placeholder_p_ds_grid{}; + std::array stride_Ds; + std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); + using GridArg = typename GridwiseGemm::Argument; + GridArg gridwise_arg(nullptr, // p_a_grid, + nullptr, // p_b_grid, + placeholder_p_ds_grid, // p_ds_grid, + nullptr, // p_e_grid , + arg.gemm_descs_[i].M_, + arg.gemm_descs_[i].N_, + arg.gemm_descs_[i].K_, + arg.gemm_descs_[i].stride_A_, + arg.gemm_descs_[i].stride_B_, + stride_Ds, + arg.gemm_descs_[i].stride_C_, + k_batch, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + + if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << "," - << K << "] are not supported by current template parameters!" 
- << " In " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__; - } - supported = false; + return false; } + + supported = supported && GridwiseGemm::CheckValidity(gridwise_arg); } return supported; @@ -651,16 +794,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + CDEElementwiseOperation, + BlkGemmPipeSched, + BlkGemmPipelineVer>; int occupancy, num_cu; hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); @@ -696,16 +845,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + CDEElementwiseOperation, + BlkGemmPipeSched, + BlkGemmPipelineVer>; int occupancy, num_cu; hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); @@ -739,6 +894,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop { auto str = std::ostringstream(); + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + // clang-format off str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop" << "<" @@ -760,8 +926,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop << CShuffleMXdlPerWavePerShuffle << ", " << CShuffleNXdlPerWavePerShuffle << ", " << getGemmSpecializationString(GemmSpec) << ", " - << PipelineVer << ", " - << LoopSched + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 84b00fcbd..e751691c4 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -908,6 +908,51 @@ struct OffsettedBlockToCTileMap UnderlyingBlockToCTileMap block_to_ctile_map_; index_t block_start_; }; +// second version with 2 offsets +template +struct OffsettedBlockToCTileMap2 +{ + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map, + index_t group_offset, + index_t tile_offset) + : block_to_ctile_map_{block_to_ctile_map}, + group_offset_{group_offset}, + tile_offset_{tile_offset} + { + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_)); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + return block_to_ctile_map_.CalculateGridSize(M, N); + } + + __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; } 
+ UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t group_offset_; + index_t tile_offset_; +}; /** * @brief Simple tile mapping which creates 3D grid of block of threads. diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index f0590c494..3a1ac6c6d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -189,55 +189,55 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { - return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch); } - __host__ static auto CalculateMPadded(index_t M) + __host__ __device__ static auto CalculateMPadded(index_t M) { return math::integer_least_multiple(M, MPerBlock); } - __host__ static auto CalculateNPadded(index_t N) + __host__ __device__ static auto CalculateNPadded(index_t N) { return math::integer_least_multiple(N, NPerBlock); } - __host__ static auto CalculateKPadded(index_t K) + __host__ __device__ static auto CalculateKPadded(index_t K) { return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; } - __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); } - __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); } - __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) { auto K_t = K_Batch * KPerBlock; return (K + K_t - 1) / K_t * KPerBlock; } - __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) { constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); auto K_t = K_Batch * KReadVec; return (K + K_t - 1) / K_t * KReadVec; } - __host__ static auto CalculateMBlock(index_t M) + __host__ __device__ static auto CalculateMBlock(index_t M) { return math::integer_divide_ceil(M, MPerBlock); } - __host__ static auto CalculateNBlock(index_t N) + __host__ __device__ static auto CalculateNBlock(index_t N) { return math::integer_divide_ceil(N, NPerBlock); } @@ -520,14 +520,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 struct Problem { - __host__ Problem(index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - std::array StrideDs_, - index_t StrideC_, - index_t KBatch_) + __host__ __device__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) : M{M_}, N{N_}, K{K_}, @@ -1180,14 +1180,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 return true; } - __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + __host__ __device__ static 
constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; return BlockwiseGemmPipe::BlockHasHotloop(num_loop); } - __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) { const index_t num_loop = K / KPerBlock; @@ -1210,8 +1210,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // return block_id to C matrix tile idx (m0, n0) mapping // if arch = gfx942 - using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; - // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; template ( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) { const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); @@ -1244,9 +1272,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -1653,6 +1678,38 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) + { + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; + Run_2Lds( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) { const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); @@ -1672,9 +1729,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp similarity index 55% rename from library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp rename to 
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp index f7c031776..3298ad940 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp @@ -17,7 +17,150 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances( +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances( std::vector> op_ptrs; - // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances( + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances( + op_ptrs); + 
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances( op_ptrs); } } @@ -132,7 +296,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; - // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) { @@ -199,7 +362,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; - // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) { @@ -266,7 +428,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; - // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) { diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt index cbfcf8d22..0ba84c5cd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt @@ -5,8 +5,22 @@ set(GROUPED_GEMM_TILE_LOOP_INSTANCES) list(APPEND GROUPED_GEMM_TILE_LOOP_INSTANCES device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp - - device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp ) add_instance_library(device_grouped_gemm_tile_loop_instance ${GROUPED_GEMM_TILE_LOOP_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp index 505afbdff..a41e6465b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp @@ -38,16 +38,16 @@ using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_irregular_tile_inst 
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 
1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 
128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp index 9653d3eef..32c3829d1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp @@ -37,19 +37,19 @@ using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_irregular_tile_inst //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, 
F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 
S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 8, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp new file mode 100644 index 000000000..d943376a3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Multiply = ck::tensor_operation::element_wise::Multiply; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; +using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = + std::tuple< + // clang-format off + //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| 
Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, 
PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 000000000..684877443 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 000000000..bb2ea76aa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 000000000..7439433f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp new file mode 100644 index 000000000..b3afed0fd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmMNPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp index 0f62510a3..c98328e52 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -31,51 +31,63 @@ using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastG using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -template -using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances = std::tuple< -// clang-format off +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if 1 - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, 
S<1, 16, 1, 8>, 8> -#endif -#if 0 - //comp - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, - - //latency - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>, - - //mem - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, 
BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8> -#endif + //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 
8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on >; +template +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = + std::tuple< + // clang-format off + //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, 
PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances( std::vector>>& instances) { + // comp add_device_operation_instances( instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances< - ck::Tuple, - ck::Tuple, - Multiply>{}); -} + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmDefault>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmMNKPadding>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmMNPadding>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + 
GemmKPadding>{}); + // mem + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmDefault, + Intrawave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNKPadding, + Intrawave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNPadding, + Intrawave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmKPadding, + Intrawave>{}); -void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances( - std::vector, - Row, - BF16, - I8, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAdd>>>& instances) -{ add_device_operation_instances( instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances< - ck::Tuple, - ck::Tuple, - MultiplyAdd>{}); + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmDefault, + Interwave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmKPadding, + Interwave>{}); } void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances( diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 000000000..b6e5961cf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmDefault, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 000000000..0662bd5fe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmKPadding, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 000000000..cb6781b7b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNKPadding, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instance.cpp new file mode 100644 index 000000000..0f2c07bf5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNPadding, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 000000000..bd003f013 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmDefault, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 000000000..d00317844 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 000000000..5810b8a3d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instance.cpp new file mode 100644 index 000000000..8b4b37ed8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 000000000..13b0622e2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + MultiplyAdd>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + MultiplyAdd>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 000000000..33696e281 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + MultiplyAddFastGelu>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + MultiplyAddFastGelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 000000000..f6e72ac2d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + MultiplyFastGelu>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + MultiplyFastGelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp new file mode 100644 index 000000000..f66564416 --- /dev/null +++ b/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp @@ -0,0 +1,347 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_gemm_multiply_tile_loop_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideDs, + const std::vector& StrideEs, + int n_warmup = 10, + int n_iter = 50) +{ + using CDataType = EDataType; + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::size_t group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideEs.size())) + { + throw std::runtime_error("wrong! 
inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + + std::vector> a_m_k; + std::vector> b_k_n; + std::vector> d_m_n; + std::vector> e_m_n_host_results; + std::vector> e_m_n_device_results; + + for(std::size_t i = 0; i < group_count; i++) + { + a_m_k.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); + b_k_n.push_back( + Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); + d_m_n.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideDs[i], DLayout{}))); + e_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{}))); + e_m_n_host_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{}))); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" + << i << "]:" << b_k_n[i].mDesc << ", e_m_n_device_results[" << i + << "]:" << e_m_n_device_results[i].mDesc << std::endl; + } + switch(init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5, 5}(a_m_k[i]); + ck::utils::FillUniformDistributionIntegerValue{-5, 5}(b_k_n[i]); + ck::utils::FillUniformDistributionIntegerValue{-5, 5}(d_m_n[i]); + break; + case 2: + ck::utils::FillUniformDistribution{.0, 1.}(a_m_k[i]); + ck::utils::FillUniformDistribution{-0.5, 0.5}(b_k_n[i]); + ck::utils::FillUniformDistribution{-0.5, 0.5}(d_m_n[i]); + break; + default: + ck::utils::FillConstant{1}(a_m_k[i]); + ck::utils::FillConstant{1}(b_k_n[i]); + ck::utils::FillConstant{1}(d_m_n[i]); + } + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using CDEElementOp = ck::tensor_operation::element_wise::Multiply; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceMemPtr = std::unique_ptr; + std::vector a_device_buf, b_device_buf, d_device_buf, e_device_buf; + + a_device_buf.reserve(group_count); + b_device_buf.reserve(group_count); + d_device_buf.reserve(group_count); + e_device_buf.reserve(group_count); + + std::vector p_a, p_b, p_d; + constexpr ck::index_t NumDTensor = 1; + auto p_ds = std::vector>{}; + std::vector p_e; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_ds.reserve(group_count); + p_e.reserve(group_count); + + using KernelArguments = + ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + + std::vector gemm_descs; + std::vector gemm_kargs; + + gemm_descs.reserve(group_count); + gemm_kargs.reserve(group_count); + + for(std::size_t i = 0; i < group_count; i++) + { + a_device_buf.emplace_back( + std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize())); + b_device_buf.emplace_back( + std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize())); + d_device_buf.emplace_back( + std::make_unique(sizeof(DDataType) * d_m_n[i].mDesc.GetElementSpaceSize())); + e_device_buf.emplace_back(std::make_unique( + sizeof(CDataType) * e_m_n_device_results[i].mDesc.GetElementSpaceSize())); + + a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); + b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); + d_device_buf[i]->ToDevice(d_m_n[i].mData.data()); + e_device_buf[i]->SetZero(); + + p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); + 
p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); + p_ds.push_back({d_device_buf[i]->GetDeviceBuffer()}); + p_e.push_back(e_device_buf[i]->GetDeviceBuffer()); + + gemm_descs.push_back( + {0, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {StrideDs[i]}}); + gemm_kargs.push_back({a_device_buf[i]->GetDeviceBuffer(), + b_device_buf[i]->GetDeviceBuffer(), + {d_device_buf[i]->GetDeviceBuffer()}, + e_device_buf[i]->GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {StrideDs[i]}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + if(do_verification) + { + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + Tensor c_m_n({Ms[i], Ns[i]}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument( + a_m_k[i], b_k_n[i], c_m_n, a_element_op, b_element_op, c_element_op); + ref_invoker.Run(ref_argument); + + for(int m = 0; m < Ms[i]; ++m) + { + for(int n = 0; n < Ns[i]; ++n) + { + cde_element_op(e_m_n_host_results[i](m, n), c_m_n(m, n), d_m_n[i](m, n)); + } + } + } + } + + // profile device GEMM instances + for(auto& gemm_ptr : op_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_e, + gemm_descs, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + cde_element_op); + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + std::string gemm_name = gemm_ptr->GetTypeString(); + + DeviceMem gemm_arg_dev_mem(gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + gemm_kargs.data(), + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer()); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + if(do_verification) + { + bool instance_pass = true; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + e_device_buf[i]->FromDevice(e_m_n_device_results[i].mData.data()); + instance_pass = instance_pass && ck::utils::check_err(e_m_n_device_results[i], + e_m_n_host_results[i]); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl; + LogRangeAsType( + std::cout << "e_device: ", e_m_n_device_results[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "e_host : ", e_m_n_host_results[i].mData, ",") + << std::endl; + } + } + + std::cout << "Instance: " << gemm_name << " verification " + << (instance_pass ? 
"SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; + } + + if(time_kernel) + { + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(EDataType) * Ms[i] * Ns[i] + // D matrix + sizeof(EDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + } + else + { + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; + } + } + + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 1cfcbfff6..fa0eb6f88 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -43,6 +43,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) endif() list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) diff --git a/profiler/src/profile_grouped_gemm_multiply_tile_loop.cpp b/profiler/src/profile_grouped_gemm_multiply_tile_loop.cpp new file mode 100644 index 000000000..5cf0af5ec --- /dev/null +++ b/profiler/src/profile_grouped_gemm_multiply_tile_loop.cpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include + +#include "profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 +}; + +enum struct GemmDataType +{ + BF16_INT8_BF16_BF16, // 0 +}; + +#define OP_NAME "grouped_gemm_multiply_tile_loop" +#define OP_DESC "Grouped GEMM Multiply Multiple D Tile Loop" + +namespace { + +std::vector argToIntArray(char* input) +{ + std::vector out; + std::istringstream in(input); + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + return out; +} + +int profile_grouped_gemm_tile_loop(int argc, char* argv[]) +{ + if(argc < 14) + { + std::cout + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: bf16@int8)\n" + << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n" + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0=n0, 1=yes)\n" + << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "optional:\n" + << "arg14: number of warm-up cycles (default 1)\n" + << "arg15: number of iterations (default 10)\n" + << std::endl; + + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const auto Ms = argToIntArray(argv[8]); + const auto Ns = argToIntArray(argv[9]); + const auto Ks = argToIntArray(argv[10]); + + auto StrideAs = argToIntArray(argv[11]); + auto StrideBs = argToIntArray(argv[12]); + auto StrideCs = argToIntArray(argv[13]); + + const int DefaultStrideA = Ks[0]; + const int DefaultStrideB = Ns[0]; + const int DefaultStrideC = Ns[0]; + + for(size_t i = 0; i < Ms.size(); ++i) + { + StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i]; + StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i]; + StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i]; + } + + std::vector StrideDs(StrideCs); + + int n_warmup = 10; + int n_iter = 50; + if(argc == 16) + { + n_warmup = std::stoi(argv[14]); + n_iter = std::stoi(argv[15]); + } + + if(data_type == GemmDataType::BF16_INT8_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_multiply_tile_loop_impl< + ck::bhalf_t, + int8_t, + ck::bhalf_t, + ck::bhalf_t, + float, + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor>(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideDs, + StrideCs, + n_warmup, + n_iter); + } + else + { + throw std::runtime_error("wrong! 
this GEMM data_type & layout is not implemented"); + } + return 0; +} + +} // anonymous namespace + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_tile_loop); -- GitLab From 8faec23cb431e38e4d08f6729a9a8f1e136dd7d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 18 Jun 2024 22:05:30 +0200 Subject: [PATCH 55/96] Add read_first_lane function for int64 (#1347) --- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 9 +++-- .../device_grouped_conv_bwd_weight_dl.hpp | 9 +++-- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 9 +++-- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 18 ++++++---- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 9 +++-- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 17 ++++++---- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 34 ++++++++++++------- ...fwd_multiple_d_multiple_r_xdl_cshuffle.hpp | 6 ++-- ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 18 ++++++---- .../ck/utility/amd_wave_read_first_lane.hpp | 24 ++++++++++++- 10 files changed, 107 insertions(+), 46 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 409e8c7b8..5e9da459c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -93,9 +93,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 83db2485a..86091aeba 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -54,9 +54,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + 
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 380a06e0d..7f88ea692 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -66,9 +66,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 963f3f254..f4f496fc1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -59,9 +59,12 @@ __global__ void const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -113,9 +116,12 @@ __global__ void const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + 
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 3bb53920b..ce86ec54e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -97,9 +97,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 88fe38add..f5a8d4e9f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -106,10 +106,12 @@ __global__ void const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); - const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx); - const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); - const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -170,10 +172,13 @@ __global__ void } else { - const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + const long_index_t a_n_offset 
= + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); GridwiseGemm::template Run( p_as_grid + a_batch_offset + a_n_offset, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index ba9d967e9..415ae3d49 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -85,12 +85,17 @@ __global__ void const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); - const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx); - const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx); - - const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); - const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -142,12 +147,17 @@ __global__ void const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); - const long_index_t a_batch_offset = compute_ptr_offset_of_groups.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_groups.GetBPtrOffset(g_idx); - const long_index_t e_batch_offset = compute_ptr_offset_of_groups.GetEPtrOffset(g_idx); - - const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); - const long_index_t e_n_offset = compute_ptr_offset_of_n.GetEPtrOffset(n_idx); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index 114fcbfcf..2170a5829 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -161,11 +161,11 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t b_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t e_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index dc639e995..49a6dc3b0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -60,9 +60,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -152,9 +155,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); - const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); - const long_index_t e_batch_offset = compute_ptr_offset_of_batch.GetEPtrOffset(g_idx); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index 741b2975a..d6e1eab31 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// 
Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -95,11 +95,33 @@ using get_carrier_t = typename get_carrier::type; } // namespace detail +__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + __device__ inline int32_t amd_wave_read_first_lane(int32_t value) { return __builtin_amdgcn_readfirstlane(value); } +__device__ inline int64_t amd_wave_read_first_lane(int64_t value) +{ + constexpr unsigned object_size = sizeof(int64_t); + constexpr unsigned second_part_offset = object_size / 2; + auto* const from_obj = reinterpret_cast(&value); + alignas(int64_t) std::byte to_obj[object_size]; + + using Sgpr = uint32_t; + + *reinterpret_cast(to_obj) = + amd_wave_read_first_lane(*reinterpret_cast(from_obj)); + *reinterpret_cast(to_obj + second_part_offset) = + amd_wave_read_first_lane(*reinterpret_cast(from_obj + second_part_offset)); + + return *reinterpret_cast(to_obj); +} + template < typename Object, typename = std::enable_if_t && std::is_trivially_copyable_v>> -- GitLab From 1973903f49947d33af88c41d173429a0057d8a59 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Wed, 19 Jun 2024 10:37:22 +0800 Subject: [PATCH 56/96] Hacking ck_tile fmha Dropout facility (#1344) * Add NullBlockDropout to be used when kHasDropout is false * Change to BlockDropout::Run() for forward to reduce conditional checkings * Re-format files --------- Co-authored-by: PoYen, Chen --- .../ck_tile/ops/fmha/block/block_dropout.hpp | 65 ++++++++++++++----- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 40 +++++------- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 10 +-- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 10 +-- 4 files changed, 79 insertions(+), 46 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index 1f0fe2bd6..7ebb306cc 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -8,6 +8,20 @@ namespace ck_tile { +struct NullBlockDropout +{ + template + __host__ __device__ static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + (void)randval_dram_block_window_tmp; + (void)seqlen_qk_start; + + return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); + } +}; + struct BlockDropout { CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, @@ -195,6 +209,42 @@ struct BlockDropout MakeRandValLdsShuffleTileDistribution()); const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + if(is_store_randval) + { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); + int block_col_start = (start_n0_idx / WG::kN) + i_n0; + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, + reinterpret_cast(rowcol)); + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + 
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // save to LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + // read from LDS to register + auto randval = load_tile(randval_lds_read_window); + // save to Global + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {0, kNPerStep}); + }); + move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); + }); + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); + }; static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); @@ -232,23 +282,8 @@ struct BlockDropout : PComputeDataType(0); }); }); - // save to Global - if(is_store_randval) - { - const auto randval_store = cast_tile(randval); - store_tile(randval_dram_window, randval_store); - move_tile_window(randval_dram_window, {0, kNPerStep}); - } }); - if(is_store_randval) - { - move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); - } }); - if(is_store_randval) - { - move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); - } } template ::max(); - uint64_t drop_seed = 0; - uint64_t drop_offset = 0; - bool is_store_randval = false; - - if constexpr(kHasDropout) - { - rp_undrop = kargs.rp_undrop; - p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; - drop_seed = kargs.drop_seed; - drop_offset = kargs.drop_offset; - is_store_randval = kargs.is_store_randval; - } - BlockDropout dropout(i_batch, - i_nhead, - kargs.num_head_q, - drop_seed, - drop_offset, - rp_undrop, - p_undrop_in_uint8_t, - is_store_randval); + auto dropout = [&]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch, + i_nhead, + kargs.num_head_q, + kargs.drop_seed, + kargs.drop_offset, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); auto randval_dram_window = [&, i_nhead_ = i_nhead]() { constexpr auto randval_dram_window_lengths = diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 06ce3a651..a392f0124 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -100,6 +100,8 @@ struct BlockFmhaPipelineQRKSVS static constexpr const char* name = "qr"; + using DropoutType = std::conditional_t; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -139,7 +141,7 @@ struct BlockFmhaPipelineQRKSVS PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout) const + DropoutType& dropout) const { static_assert( std::is_same_v> && @@ -246,7 +248,7 @@ struct BlockFmhaPipelineQRKSVS {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto randval_dram_window = dropout.MakeRandvalDramWindow( + auto randval_dram_window = dropout.template MakeRandvalDramWindow( randval_dram_block_window_tmp, seqlen_k_start); auto v_dram_window = @@ -486,7 +488,7 @@ struct BlockFmhaPipelineQRKSVS if constexpr(kHasDropout) { - dropout.Run( + dropout.template Run( smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); } @@ -618,7 +620,7 @@ struct BlockFmhaPipelineQRKSVS 
PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout) const + DropoutType& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 21784fc2d..e9a14ca5a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -112,6 +112,8 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr const char* name = "qr_async"; + using DropoutType = std::conditional_t; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -151,7 +153,7 @@ struct BlockFmhaPipelineQRKSVSAsync PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout) const + DropoutType& dropout) const { static_assert( std::is_same_v> && @@ -298,7 +300,7 @@ struct BlockFmhaPipelineQRKSVSAsync {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto randval_dram_window = dropout.MakeRandvalDramWindow( + auto randval_dram_window = dropout.template MakeRandvalDramWindow( randval_dram_block_window_tmp, seqlen_k_start); auto v_dram_window = @@ -571,7 +573,7 @@ struct BlockFmhaPipelineQRKSVSAsync { auto randval_ptr = reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.Run( + dropout.template Run( randval_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, @@ -728,7 +730,7 @@ struct BlockFmhaPipelineQRKSVSAsync PositionEncoding position_encoding, float scale_s, void* smem_ptr, - BlockDropout& dropout) const + DropoutType& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, -- GitLab From 8db331a511e756c31271b6b60cca9bfadcae854b Mon Sep 17 00:00:00 2001 From: zjing14 Date: Wed, 19 Jun 2024 13:47:18 -0500 Subject: [PATCH 57/96] Remove gfx900 and gfx906 from default target device to reduce package size (#1351) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f9e44583..e8626b2cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -112,7 +112,7 @@ message("checking which targets are supported") #Setting GPU_TARGETS on command line will override this list if(NOT PROFILER_ONLY) rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") + TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") else() add_definitions(-DPROFILER_ONLY) set(GPU_TARGETS "" CACHE STRING "" FORCE) -- GitLab From e3f44659cf77df8c3de15eb14baffd58be6ac550 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Thu, 20 Jun 2024 11:36:42 +0800 Subject: [PATCH 58/96] Fix in dropout lambda to avoid the compiling issue on some docker/compiler envs (#1350) --- include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 6db432c83..5ecc3a4d8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -744,11 +744,11 @@ struct FmhaFwdKernel } }(); - auto dropout = [&]() { + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { if constexpr(kHasDropout) { - return 
BlockDropout{i_batch, - i_nhead, + return BlockDropout{i_batch_, + i_nhead_, kargs.num_head_q, kargs.drop_seed, kargs.drop_offset, -- GitLab From 0162a5f6ba60a04ce6a134304e09f933e46dbcfa Mon Sep 17 00:00:00 2001 From: ThruptiRajLakshmanaGowda Date: Thu, 20 Jun 2024 09:24:54 -0500 Subject: [PATCH 59/96] Adding Missed Activation Functions for Grouped 2D/3D Convolutions (#1348) * Initial Push * First Push * Fixed Clang format * Resolve merge conflict * Addressed review comments * Addressed review comments * Addressed review comments --- example/62_convnd_activ/unary/CMakeLists.txt | 10 ++++++++++ .../unary/convnd_fwd_xdl_logistic_fp16.cpp | 11 +++++++++++ .../unary/convnd_fwd_xdl_passthrough_fp16.cpp | 11 +++++++++++ .../unary/convnd_fwd_xdl_swish_fp16.cpp | 11 +++++++++++ .../element/unary_element_wise_operation.hpp | 18 ++++++++++++++++++ 5 files changed, 61 insertions(+) create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_logistic_fp16.cpp create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_passthrough_fp16.cpp create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_swish_fp16.cpp diff --git a/example/62_convnd_activ/unary/CMakeLists.txt b/example/62_convnd_activ/unary/CMakeLists.txt index 94ffb3661..3470e9b94 100644 --- a/example/62_convnd_activ/unary/CMakeLists.txt +++ b/example/62_convnd_activ/unary/CMakeLists.txt @@ -30,6 +30,16 @@ foreach(gpu IN LISTS GPU_TARGETS) # Elu add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp) add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_elu_fp16) + # Swish + add_example_executable(example_convnd_fwd_xdl_swish_fp16 convnd_fwd_xdl_swish_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_swish_fp16) + # PassThrough + add_example_executable(example_convnd_fwd_xdl_passthrough_fp16 convnd_fwd_xdl_passthrough_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_passthrough_fp16) + # Logistic + add_example_executable(example_convnd_fwd_xdl_logistic_fp16 convnd_fwd_xdl_logistic_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_logistic_fp16) + set(target 1) endif() endforeach() diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_logistic_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_logistic_fp16.cpp new file mode 100644 index 000000000..86811c2e9 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_logistic_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Logistic; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_passthrough_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_passthrough_fp16.cpp new file mode 100644 index 000000000..7167c4a84 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_passthrough_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_swish_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_swish_fp16.cpp new file mode 100644 index 000000000..65a2a5023 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_swish_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Swish; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 28514fb78..75429554a 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -961,6 +961,24 @@ struct Elu const float alpha_; }; +struct Logistic +{ + Logistic(float alpha = 1.f) : alpha_(alpha){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + } + const float alpha_; +}; + struct ConvInvscale { __host__ __device__ ConvInvscale(float scale_in = 1.f, -- GitLab From 1da802bdf2cc56defeaacf73c63461fe3c4cf692 Mon Sep 17 00:00:00 2001 From: Dan Yao Date: Thu, 20 Jun 2024 22:50:53 +0800 Subject: [PATCH 60/96] Fix FA bwd alibi+causal NaN errors (#1352) * fix bwd alibi nan error * fix datatype --------- Co-authored-by: danyao12 --- include/ck_tile/ops/fmha/block/block_masking.hpp | 2 +- .../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index f43de4573..ce8493663 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -372,7 +372,7 @@ struct SimplifiedGenericAttentionMask // index_t x_end = min(i_y + x, x_total); bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad - bool bottom_left_edge = i_y_end > (i_x + y); + bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now return top_right_edge || bottom_left_edge; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index a013ee3d5..d867772a1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -501,9 +501,9 @@ struct 
BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor() { - using QGradDataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; constexpr index_t Banks = 32; // TODO: need change based on arch - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QGradDataType); + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(OGradDataType); constexpr index_t kKPack = GetSmemKPackOGrad(); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPerBlock = [&]() { -- GitLab From 510325a46898dfc6de8187ca806ed3565c0a22e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 21 Jun 2024 09:47:58 +0200 Subject: [PATCH 61/96] Fix cmake warnings (#1342) * Cmake add -Wno-nvcc-compt * Remove template without initialization list * dpp remove template without init list * Fixes --- cmake/EnableCompilerWarnings.cmake | 3 +- .../gpu/block/blockwise_gemm_dpp.hpp | 8 ++-- .../block/blockwise_gemm_pipeline_xdlops.hpp | 41 ++++++++----------- .../blockwise_gemm_pipeline_xdlops_v1.hpp | 15 ++++--- .../blockwise_gemm_pipeline_xdlops_v2.hpp | 28 ++++++------- .../blockwise_gemm_pipeline_xdlops_v3.hpp | 11 +++-- .../blockwise_gemm_pipeline_xdlops_v4.hpp | 25 +++++------ .../blockwise_gemm_pipeline_xdlops_v5.hpp | 25 +++++------ .../gpu/block/blockwise_gemm_wmma.hpp | 16 ++++---- .../gpu/block/blockwise_gemm_xdlops.hpp | 23 +++++------ .../blockwise_gemm_xdlops_skip_b_lds.hpp | 9 ++-- ..._grouped_convnd_fwd_multi_ab_interface.cpp | 10 ++--- 12 files changed, 97 insertions(+), 117 deletions(-) diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 8654170b3..fb2b38d68 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -2,7 +2,7 @@ # # MIT License # -# Copyright (c) 2017 Advanced Micro Devices, Inc. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -96,6 +96,7 @@ else() -Wno-covered-switch-default -Wno-unsafe-buffer-usage -Wno-unused-lambda-capture + -Wno-nvcc-compat ) else() if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp index d62ed4b15..f03427a7e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -300,9 +300,9 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - dpp_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + dpp_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp index 5d137e67e..1121cc455 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -613,7 +613,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( + xdlops_gemm.Run( a_thread_vec.template AsType(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -681,7 +681,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( + xdlops_gemm.Run( a_thread_vec.template AsType(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -749,10 +749,9 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -808,10 +807,9 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -840,10 +838,9 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -901,10 +898,9 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -939,10 +935,9 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - 
a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index 7dd86468b..f597573dc 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -259,7 +259,7 @@ struct BlockwiseGemmXdlops_pipeline_v1(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -319,10 +319,9 @@ struct BlockwiseGemmXdlops_pipeline_v1(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -584,7 +583,7 @@ struct BlockwiseGemmXdlops_pipeline_v1(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -668,7 +667,7 @@ struct BlockwiseGemmXdlops_pipeline_v1(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index dad643ffa..711c47854 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -303,7 +303,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -374,7 +374,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -428,10 +428,9 @@ struct BlockwiseGemmXdlops_pipeline_v2(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -480,10 +479,9 @@ struct BlockwiseGemmXdlops_pipeline_v2(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -821,7 +819,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -914,7 +912,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -990,7 +988,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -1066,7 +1064,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index 52f48d0e4..d47318dd0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -381,7 +381,7 @@ struct BlockwiseGemmXdlops_pipeline_v3(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -440,10 +440,9 @@ struct BlockwiseGemmXdlops_pipeline_v3(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp index 51ce8ae61..bd5a1bedf 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -403,7 +403,7 @@ struct BlockwiseGemmXdlops_pipeline_v4(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -472,10 +472,9 @@ struct BlockwiseGemmXdlops_pipeline_v4(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -529,10 +528,9 @@ struct BlockwiseGemmXdlops_pipeline_v4(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -562,10 +560,9 @@ struct BlockwiseGemmXdlops_pipeline_v4(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp index 8569b680e..b6a4f0550 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -444,7 +444,7 @@ struct BlockwiseGemmXdlops_pipeline_v5(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -513,10 +513,9 @@ struct BlockwiseGemmXdlops_pipeline_v5(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); a_thread_copy_.Run( a_block_desc_m0_m1_m2_k, @@ -564,10 +563,9 @@ struct BlockwiseGemmXdlops_pipeline_v5(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); a_thread_copy_.Run( @@ -607,10 +605,9 @@ struct BlockwiseGemmXdlops_pipeline_v5(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index f8ee283c6..873539f8b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -352,10 +352,9 @@ struct BlockwiseGemmWMMA constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - wmma_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -411,10 +410,9 @@ struct BlockwiseGemmWMMA constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - wmma_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index e5e6245cb..e2296a55f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -340,10 +340,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -530,10 +529,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // TODO: insert setprio in more precise manner since we // could have more than >1 MFMA instructions in single call - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) { __builtin_amdgcn_sched_barrier(0); @@ -963,10 +961,9 @@ struct BlockwiseGemmXdlops_v2 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index 8ae1ba3f3..287c6701c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -281,10 +281,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp index c529a6a61..346f04f66 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -207,7 +207,7 @@ TEST_F(TestGroupedConvndFwdMultiAInterface, MultiA) std::array as{nullptr, nullptr}; const void* b = nullptr; - EXPECT_TRUE(this->template Run(as, b)); + EXPECT_TRUE(this->Run(as, b)); } TEST_F(TestGroupedConvndFwdMultiBInterface, MultiB) @@ -215,7 +215,7 @@ TEST_F(TestGroupedConvndFwdMultiBInterface, MultiB) const void* a = nullptr; std::array bs{nullptr, nullptr}; - EXPECT_TRUE(this->template Run(a, bs)); + EXPECT_TRUE(this->Run(a, bs)); } TEST_F(TestGroupedConvndFwdMultiABInterface, MultiAB) @@ -223,7 +223,7 @@ TEST_F(TestGroupedConvndFwdMultiABInterface, MultiAB) std::array as{nullptr, nullptr}; std::array bs{nullptr, nullptr}; - EXPECT_TRUE(this->template Run(as, bs)); + EXPECT_TRUE(this->Run(as, bs)); } TEST_F(TestGroupedConvndFwdInterface, SingleAB) @@ -231,5 +231,5 @@ TEST_F(TestGroupedConvndFwdInterface, SingleAB) const void* a = nullptr; const void* b = nullptr; - EXPECT_TRUE(this->template Run(a, b)); + EXPECT_TRUE(this->Run(a, b)); } -- GitLab From fa129c1a5db62354c4b39857d2b1598bb618f8ce Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 22 Jun 2024 00:00:13 +0800 Subject: [PATCH 62/96] WA for rocm-6.2+ s constrait for buffer resource (#1346) * WA for rocm-6.2+ s constrait for buffer resource * add missing memory clobber --- include/ck/utility/amd_buffer_addressing.hpp | 3 ++- include/ck_tile/core/arch/amd_buffer_addressing.hpp | 10 ++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index cfa4cabee..ab22134fc 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -991,7 +991,8 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "v"(global_offset_bytes), - "s"(src_resource)); + "s"(src_resource) + : "memory"); #else // LDS pointer must be attributed with the LDS address space. 
__attribute__((address_space(3))) uint32_t* lds_ptr = diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 13e92ef0b..2cd8bb5f0 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -26,7 +26,12 @@ struct __attribute__((packed)) buffer_resource CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; - return __builtin_bit_cast(int32x4_t, res); + int32x4_t r = __builtin_bit_cast(int32x4_t, res); + r.x = __builtin_amdgcn_readfirstlane(r.x); + r.y = __builtin_amdgcn_readfirstlane(r.y); + r.z = __builtin_amdgcn_readfirstlane(r.z); + r.w = __builtin_amdgcn_readfirstlane(r.w); + return r; } namespace impl { @@ -2104,7 +2109,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "v"(global_offset_bytes), - "s"(src_resource)); + "s"(src_resource) + : "memory"); #else // LDS pointer must be attributed with the LDS address space. __attribute__((address_space(3))) uint32_t* lds_ptr = -- GitLab From 05b10e0e5a3ede297a468f104b75abf1056acf7e Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Fri, 21 Jun 2024 19:02:57 -0600 Subject: [PATCH 63/96] Add instances of grouped convolution 3d forward with a ConvScale element-wise op for bf8@bf8->fp8 (#1326) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We are adding more instances of grouped convolution 3d forward with a ConvScale element-wise operation. This commit handles bf8@bf8->fp8 data types combination. * Included an example. * Added instances. * Added a client example. 
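A minimal sketch of the ConvScale epilogue these bf8 instances use, under stated assumptions: the constructor argument order (scale_in, scale_wei, scale_out) is inferred from the ConvInvscale functor visible in patch 59 (#1348) above, and the op(y, x) call convention from the Logistic functor added there; the function name below is purely illustrative, and the authoritative usage remains the conv3d_fwd_convscale_bf8.cpp example added later in this patch.

    // Sketch only, under the assumptions stated above: ConvScale is taken to fold the
    // input, weight and output quantization scales into the epilogue before the fp8 cast.
    #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

    void convscale_epilogue_sketch()
    {
        using ck::tensor_operation::element_wise::ConvScale;
        // assumed constructor order: (scale_in, scale_wei, scale_out)
        ConvScale out_op{/*scale_in=*/0.5f, /*scale_wei=*/0.25f, /*scale_out=*/2.0f};
        float acc = 1.0f; // float accumulator (CShuffleDataType) from the implicit-GEMM convolution
        ck::f8_t e;
        out_op(e, acc);   // assumed to yield roughly fp8(acc * scale_in * scale_wei * scale_out)
    }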
--------- Co-authored-by: Rostyslav Geyyer Co-authored-by: Bartłomiej Kocot --- .../24_grouped_conv_activation/CMakeLists.txt | 7 +- .../conv3d_fwd_convscale_bf8.cpp | 50 +++++++++++ client_example/CMakeLists.txt | 37 +++----- .../62_convnd_activ/convscale/CMakeLists.txt | 17 ++-- .../convnd_fwd_xdl_convscale_bf8.cpp | 88 +++++++++++++++++++ ...ped_conv_fwd_xdl_outelementop_instance.hpp | 37 ++++++++ .../grouped_convolution_forward_convscale.hpp | 28 +++++- .../CMakeLists.txt | 1 + ...cale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp | 62 +++++++++++++ 9 files changed, 295 insertions(+), 32 deletions(-) create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp create mode 100644 example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index b0c895d8a..a624302db 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -40,9 +40,14 @@ add_executable(client_conv3d_fwd_convinvscale_fp8 grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp) target_link_libraries(client_conv3d_fwd_convinvscale_fp8 PRIVATE composable_kernel::device_conv_operations) # Fwd convscale -add_executable(client_conv3d_fwd_convscale_fp8 +add_executable(client_conv3d_fwd_convscale_fp8 grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp) target_link_libraries(client_conv3d_fwd_convscale_fp8 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_conv3d_fwd_convscale_bf8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_bf8 PRIVATE composable_kernel::device_conv_operations) + add_executable(client_conv3d_fwd_convscale_fp8_bf8 grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp) target_link_libraries(client_conv3d_fwd_convscale_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp new file mode 100644 index 000000000..f901d08ab --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = InDataType; +using BComputeDataType = AComputeDataType; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 8eb662d28..d2222a840 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -6,46 +6,36 @@ if (DTYPES) add_definitions(-DDTYPES) if (DTYPES MATCHES "int8") add_definitions(-DCK_ENABLE_INT8) - if(NOT DEFINED ${CK_ENABLE_INT8}) - set(CK_ENABLE_INT8 "ON") - endif() + set(CK_ENABLE_INT8 "ON") endif() if (DTYPES MATCHES "fp8") add_definitions(-DCK_ENABLE_FP8) - if(NOT DEFINED ${CK_ENABLE_FP8}) - set(CK_ENABLE_FP8 "ON") - endif() + set(CK_ENABLE_FP8 "ON") + endif() + if (DTYPES MATCHES "bf8") + add_definitions(-DCK_ENABLE_BF8) + set(CK_ENABLE_BF8 "ON") endif() if (DTYPES MATCHES "fp16") add_definitions(-DCK_ENABLE_FP16) - if(NOT DEFINED ${CK_ENABLE_FP16}) - set(CK_ENABLE_FP16 "ON") - endif() + set(CK_ENABLE_FP16 "ON") endif() if (DTYPES MATCHES "fp32") add_definitions(-DCK_ENABLE_FP32) - if(NOT DEFINED ${CK_ENABLE_FP32}) - set(CK_ENABLE_FP32 "ON") - endif() + set(CK_ENABLE_FP32 "ON") endif() if (DTYPES MATCHES "fp64") add_definitions(-DCK_ENABLE_FP64) - if(NOT DEFINED ${CK_ENABLE_FP64}) - set(CK_ENABLE_FP64 "ON") - endif() + set(CK_ENABLE_FP64 "ON") endif() if (DTYPES MATCHES "bf16") add_definitions(-DCK_ENABLE_BF16) - if(NOT DEFINED ${CK_ENABLE_BF16}) - set(CK_ENABLE_BF16 "ON") - endif() + set(CK_ENABLE_BF16 "ON") endif() message("DTYPES macro set to ${DTYPES}") else() - add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) - if(NOT DEFINED ${CK_ENABLE_ALL_DTYPES}) - set(CK_ENABLE_ALL_DTYPES "ON") - endif() + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) + set(CK_ENABLE_ALL_DTYPES "ON") endif() if (GPU_TARGETS) @@ -73,7 +63,8 @@ message(STATUS "Build with HIP ${hip_VERSION}") # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) FOREACH(subdir ${dir_list}) - IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build")) + IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build") + AND (NOT "${subdir}" MATCHES ".vscode")) add_subdirectory(${subdir}) ENDIF() ENDFOREACH() diff --git a/example/62_convnd_activ/convscale/CMakeLists.txt b/example/62_convnd_activ/convscale/CMakeLists.txt index d6abb32f2..3de1aff67 100644 --- 
a/example/62_convnd_activ/convscale/CMakeLists.txt +++ b/example/62_convnd_activ/convscale/CMakeLists.txt @@ -2,11 +2,16 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_activ_xdl_convscale) - add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp) - add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8) - add_example_executable(example_convnd_fwd_xdl_convscale_fp8_bf8 convnd_fwd_xdl_convscale_fp8_bf8.cpp) - add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8_bf8) - set(target 1) + add_custom_target(example_convnd_activ_xdl_convscale) + add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8 ) + + add_example_executable(example_convnd_fwd_xdl_convscale_bf8 convnd_fwd_xdl_convscale_bf8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8) + + add_example_executable(example_convnd_fwd_xdl_convscale_fp8_bf8 convnd_fwd_xdl_convscale_fp8_bf8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8_bf8) + + set(target 1) endif() endforeach() diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp new file mode 100644 index 000000000..c1c8c3a57 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = InDataType; +using BComputeDataType = AComputeDataType; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 
0 : 1; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp index 6fbbaca7b..0576873b8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp @@ -73,6 +73,43 @@ using device_grouped_conv_fwd_xdl_outelementop_f8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_outelementop_bf8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute Type| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, 
BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8> +#endif + // clang-format on + >; + template >>& instances); #endif -#if defined(CK_ENABLE_FP8) 
&& defined(CK_ENABLE_BF8) +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector, + NDHWGK, + BF8, + BF8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + BF8>>>& instances); + void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( std::vector && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + op_ptrs); + } + if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt index aef9c10c2..b4cfd1a23 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt @@ -1,6 +1,7 @@ # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_CONVSCALE xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp) add_instance_library(device_grouped_conv3d_fwd_convscale_instance ${GROUPED_CONV3D_FWD_CONVSCALE}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp new file mode 100644 index 000000000..52cda2ea9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
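+// Defines add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances declared above: +// it appends the bf8 -> f8 ConvScale instance lists for the Default, 1x1 Pad0 and 1x1 Stride1 Pad0 +// forward specializations.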
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector, + NDHWGK, + BF8, + BF8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + BF8>>>& instances) +{ + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvScale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -- GitLab From cb13839425e0ec4dfff5b8138104ee3e3183d050 Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 24 Jun 2024 08:45:52 +0800 Subject: [PATCH 64/96] layernorm2d forward (#1339) * Add layernorm2d forward * Refind file path * clang format * Exclude ck_tile op from all * use add_executable instead * refactor layernorm2d_fwd example --------- Co-authored-by: carlushuang --- example/ck_tile/02_layernorm2d/CMakeLists.txt | 4 + example/ck_tile/02_layernorm2d/README.md | 22 ++ .../02_layernorm2d/layernorm2d_fwd.cpp | 191 ++++++++++++ .../02_layernorm2d/layernorm2d_fwd.hpp | 30 ++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/core.hpp | 1 + include/ck_tile/core/numeric/null_type.hpp | 13 + include/ck_tile/host.hpp | 1 + include/ck_tile/host/check_err.hpp | 25 +- .../host/reference/reference_layernorm2d.hpp | 69 +++++ include/ck_tile/ops/layernorm2d.hpp | 9 + .../kernel/layernorm2d_fwd_kernel.hpp | 291 ++++++++++++++++++ .../block_layernorm2d_fwd_problem.hpp | 30 ++ .../pipeline/tile_layernorm2d_fwd_shape.hpp | 35 +++ include/ck_tile/ops/welford.hpp | 8 + .../ops/welford/thread/thread_welford.hpp | 101 ++++++ .../ck_tile/ops/welford/warp/warp_welford.hpp | 154 +++++++++ 17 files changed, 975 insertions(+), 10 deletions(-) create mode 100644 example/ck_tile/02_layernorm2d/CMakeLists.txt create mode 100644 example/ck_tile/02_layernorm2d/README.md create mode 100644 example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp create mode 100644 example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp create mode 100644 include/ck_tile/core/numeric/null_type.hpp create mode 100644 include/ck_tile/host/reference/reference_layernorm2d.hpp create mode 100644 include/ck_tile/ops/layernorm2d.hpp create mode 100644 include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp create mode 100644 include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp create mode 100644 include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp create mode 100644 include/ck_tile/ops/welford.hpp create mode 100644 include/ck_tile/ops/welford/thread/thread_welford.hpp create mode 100644 include/ck_tile/ops/welford/warp/warp_welford.hpp diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt new file mode 100644 index 
000000000..bac5f45cd --- /dev/null +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -0,0 +1,4 @@ +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) +target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD) \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md new file mode 100644 index 000000000..433dad04e --- /dev/null +++ b/example/ck_tile/02_layernorm2d/README.md @@ -0,0 +1,22 @@ +# Layernorm2D forward + +This folder contains an example of Layernorm2D forward using the ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this with gfx90a, gfx942... +make tile_example_layernorm2d_fwd -j +``` +This will result in an executable `build/bin/tile_example_layernorm2d_fwd` + +## example +``` +args: + -m m dimension (default:3328) + -n n dimension (default:4096) + -e epsilon (default:1e-5) + -v cpu validation or not (default:1) + -prec precision (default:fp16) +``` \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp new file mode 100644 index 000000000..9cbd28610 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -0,0 +1,191 @@ +#include "ck_tile/host.hpp" +#include "layernorm2d_fwd.hpp" +#include + +// Host API implementation +float layernorm2d_fwd(layernorm2d_fwd_traits t, + layernorm2d_fwd_args a, + const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + using XDataType = ck_tile::half_t; + using YDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using BetaDataType = ck_tile::half_t; +#ifdef SAVE_MEAN_INV_STD + using MeanDataType = ck_tile::half_t; + using InvStdDataType = ck_tile::half_t; +#else + using MeanDataType = ck_tile::null_type; + using InvStdDataType = ck_tile::null_type; +#endif + using ComputeDataType = float; + + using thread_tile = ck_tile::sequence<4, 4>; + using warp_tile = ck_tile::sequence<8, 128>; + using block_tile = ck_tile::sequence<32, 128>; + + using Shape = ck_tile::TileLayernorm2dShape; + + using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem; + + using Kernel = ck_tile::Layernorm2dFwd; + + auto kargs = Kernel::MakeKargs( + a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N); + + const dim3 grids = Kernel::GridSize(a.M); + constexpr dim3 blocks = Kernel::BlockSize(); + + constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock; + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + + return 0; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("e", "1e-5", "epsilon") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +int main(int argc, char* argv[]) +{ + + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + float epsilon = arg_parser.get_float("e"); + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t 
N = arg_parser.get_int("n"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + + using XDataType = ck_tile::half_t; + using YDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using BetaDataType = ck_tile::half_t; +#ifdef SAVE_MEAN_INV_STD + using MeanDataType = ck_tile::half_t; + using InvStdDataType = ck_tile::half_t; +#else + using MeanDataType = ck_tile::null_type; + using InvStdDataType = ck_tile::null_type; +#endif + using ComputeDataType = float; + + // host verify + ck_tile::HostTensor x_host({M, N}); + ck_tile::HostTensor gamma_host({N}); + ck_tile::HostTensor beta_host({N}); + + ck_tile::HostTensor y_host_ref({M, N}); + ck_tile::HostTensor y_host_dev({M, N}); + + ck_tile::HostTensor mean_host_ref({M}); + ck_tile::HostTensor invStd_host_ref({M}); + +#ifdef SAVE_MEAN_INV_STD + ck_tile::HostTensor mean_host_dev({M}); + ck_tile::HostTensor invStd_host_dev({M}); +#endif + + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(gamma_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(beta_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + +#ifdef SAVE_MEAN_INV_STD + ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes()); +#endif + + x_buf.ToDevice(x_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + beta_buf.ToDevice(beta_host.data()); + + layernorm2d_fwd_traits traits{data_type}; + + layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + beta_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + mean_buf.GetDeviceBuffer(), + invStd_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + epsilon, + M, + N}; + + float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true}); + + std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << "[" << data_type << "]" + << " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s" + << std::flush; + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_layernorm2d_fwd( + x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + + y_buf.FromDevice(y_host_dev.data()); + + pass = ck_tile::check_err(y_host_dev, y_host_ref); + +#ifdef SAVE_MEAN_INV_STD + mean_buf.FromDevice(mean_host_dev.data()); + pass &= ck_tile::check_err(mean_host_dev, mean_host_ref); + + invStd_buf.FromDevice(invStd_host_dev.data()); + pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref); +#endif + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; + } + + std::cout << std::endl << std::flush; + + return !pass; +} diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp new file mode 100644 index 000000000..4d1aac099 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
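+// Host-side API of the example: layernorm2d_fwd_traits selects the data type by string and +// layernorm2d_fwd_args carries the raw device pointers plus epsilon and the (M, N) problem size. +// A minimal call sketch, mirroring main() in layernorm2d_fwd.cpp (pointer and size variables are +// illustrative placeholders): +// layernorm2d_fwd_traits t{"fp16"}; +// layernorm2d_fwd_args a{p_x, p_gamma, p_beta, p_y, nullptr, nullptr, 1e-5f, M, N}; +// float ms = layernorm2d_fwd(t, a, ck_tile::stream_config{nullptr, true});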
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/layernorm2d.hpp" +#include + +struct layernorm2d_fwd_traits +{ + std::string data_type; +}; + +struct layernorm2d_fwd_args +{ + const void* p_x; + const void* p_gamma; + const void* p_beta; + void* p_y; + void* p_mean; + void* p_invStd; + float epsilon; + ck_tile::index_t M; + ck_tile::index_t N; +}; + +// host API +float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index d2b086e04..995d193f1 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -3,3 +3,4 @@ include_directories(AFTER ) add_subdirectory(01_fmha) +add_subdirectory(02_layernorm2d) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index bb490cce4..4cddf6faa 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -27,6 +27,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" diff --git a/include/ck_tile/core/numeric/null_type.hpp b/include/ck_tile/core/numeric/null_type.hpp new file mode 100644 index 000000000..8799c0560 --- /dev/null +++ b/include/ck_tile/core/numeric/null_type.hpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include + +namespace ck_tile { + +struct null_type +{ +}; + +} // namespace ck_tile diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 09030fa6d..0e69a925d 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -18,6 +18,7 @@ #include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" +#include "ck_tile/host/reference/reference_layernorm2d.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 1ef9b2413..529bfdff2 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -56,8 +56,9 @@ check_err(const Range& out, } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; @@ -114,8 +115,9 @@ check_err(const Range& out, } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && 
both_infinite_and_same); }; @@ -173,8 +175,9 @@ check_err(const Range& out, } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; @@ -285,8 +288,9 @@ std::enable_if_t<(std::is_same_v, ranges::range_val } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; @@ -357,8 +361,9 @@ std::enable_if_t<(std::is_same_v, ranges::range_val } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; diff --git a/include/ck_tile/host/reference/reference_layernorm2d.hpp b/include/ck_tile/host/reference/reference_layernorm2d.hpp new file mode 100644 index 000000000..837f52c39 --- /dev/null +++ b/include/ck_tile/host/reference/reference_layernorm2d.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
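+// Host reference: for each row m, mean and variance over the N columns are accumulated with +// Welford's one-pass update, then y(m, n) = (x(m, n) - mean) / sqrt(variance + epsilon) * gamma(n) + beta(n); +// the per-row mean and 1/sqrt(variance + epsilon) are stored only when MeanDataType/InvStdDataType +// are not null_type.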
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +void reference_layernorm2d_fwd(const HostTensor& x_m_n, + const HostTensor& gamma_n, + const HostTensor& beta_n, + HostTensor& y_m_n, + HostTensor& mean_m, + HostTensor& invStd_m, + ComputeDataType epsilon) +{ + auto layernorm2d_fwd_func = [&](auto m) { + const int N = x_m_n.mDesc.get_lengths()[1]; + + int count = 0; + ComputeDataType mean = 0; + ComputeDataType variance = 0; + ComputeDataType divisor = 0; + + for(int n = 0; n < N; ++n) + { + ++count; + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + ComputeDataType delta = x - mean; + mean += delta / count; + ComputeDataType delta2 = x - mean; + variance += delta * delta2; + } + + // actual variance + variance = variance / count; + divisor = ck_tile::type_convert(1) / ck_tile::sqrt(variance + epsilon); + + if constexpr(!std::is_same_v) + mean_m(m) = ck_tile::type_convert(mean); + + if constexpr(!std::is_same_v) + invStd_m(m) = ck_tile::type_convert(divisor); + + for(int n = 0; n < N; ++n) + { + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); + ComputeDataType beta = ck_tile::type_convert(beta_n(n)); + auto y = (x - mean) * divisor; + y = y * gamma + beta; + + y_m_n(m, n) = ck_tile::type_convert(y); + } + }; + + make_ParallelTensorFunctor(layernorm2d_fwd_func, + mean_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp new file mode 100644 index 000000000..3b66645ed --- /dev/null +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp new file mode 100644 index 000000000..4be3e5687 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -0,0 +1,291 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
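+// Two-pass row layernorm kernel: pass 1 streams x tile-by-tile through the per-thread Welford +// accumulator and merges the partial mean/variance across lanes with WarpMergeWelford; pass 2 +// walks the row back (right to left, to reuse data still in cache) and stores +// y = (x - mean) * inv_std * gamma + beta, where inv_std = 1 / sqrt(variance + epsilon). +// Mean and inv_std are optionally written out when their data types are not null_type.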
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/welford/thread/thread_welford.hpp" +#include "ck_tile/ops/welford/warp/warp_welford.hpp" + +namespace ck_tile { + +// TODO: Extract some type to wrapper class +template +struct Layernorm2dFwd +{ + using Problem = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using BetaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using MeanDataType = ck_tile::remove_cvref_t; + using InvStdDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kHasBeta = !std::is_same_v; + static constexpr bool kSaveMean = !std::is_same_v; + static constexpr bool kSaveInvStd = !std::is_same_v; + + static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; + static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; + + static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; + + struct Kargs + { + const void* p_x; + const void* p_gamma; + const void* p_beta; + + void* p_y; + void* p_mean; + void* p_invStd; + + float epsilon; + + ck_tile::index_t M; + ck_tile::index_t N; + }; + + CK_TILE_HOST static constexpr Kargs MakeKargs(const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_mean, + void* p_invStd, + float epsilon, + ck_tile::index_t M, + ck_tile::index_t N) + { + return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N}; + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; } + + CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; } + + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 1>>, + sequence<1, 2>, + sequence<2, 2>>{}); + } + + CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 1>>, + sequence<1>, + sequence<2>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr) + { + constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>(); + + using Lengths = decltype(nDstrSpan.impl_); + + ck_tile::index_t ret = 1; + + ck_tile::static_for<0, Lengths::size(), 1>{}( + [&](auto idx) { ret *= Lengths::template at(idx); }); + + return ret; + } + + template + CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor, + const ComputeDataType epsilon) + { + // TODO: Investigate fast inverse square root algorithm with epsilon + constexpr auto spans = DistributedTensor::get_distributed_spans(); + + DistributedTensor out_dstr_tensor; + + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + out_dstr_tensor(i_idx) = type_convert(1.0f) / + ck_tile::sqrt(in_dstr_tensor[i_idx] + epsilon); + }); + + return out_dstr_tensor; + } + + template + CK_TILE_DEVICE std::enable_if_t TwoPassLayernorm2dFwd(const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y, + 
MeanDataType* p_mean, + InvStdDataType* p_invStd, + const ComputeDataType epsilon, + ck_tile::index_t M, + ck_tile::index_t N) const + { + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{}); + + const auto gamma_n = make_naive_tensor_view( + p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{}); + + const auto beta_n = make_naive_tensor_view( + p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{}); + + const auto iM = get_block_id() * kMPerBlock; + + constexpr auto xDstr = MakeXBlockTileDistribution(); + + auto x_block_window = make_tile_window( + x_m_n, make_tuple(number{}, number{}), {iM, 0}, xDstr); + + index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock); + + // TODO: padding - handle max_count if N % kNPerBlock != 0 + constexpr auto NPerThread = GetNPerThread(xDstr); + ThreadWelford thread_welford{ + type_convert(NPerThread * N / kNPerBlock)}; + + using XTensorType = decltype(load_tile(x_block_window)); + auto mean_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + auto var_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + + clear_tile(mean_compute_block_tensor); + clear_tile(var_compute_block_tensor); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x_block_tensor = load_tile(x_block_window); + + thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); + move_tile_window(x_block_window, {0, kNPerBlock}); + } + + // TODO: support cross warp Welford + WarpMergeWelford{}( + mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_); + + auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); + + if constexpr(kSaveMean) + { + const auto mean_m = make_naive_tensor_view_packed( + p_mean, make_tuple(M), number<32>{}); + + auto mean_block_window = + make_tile_window(mean_m, make_tuple(number{}), {iM}); + + store_tile(mean_block_window, cast_tile(mean_compute_block_tensor)); + } + if constexpr(kSaveInvStd) + { + const auto inv_std_m = make_naive_tensor_view_packed( + p_invStd, make_tuple(M), number<32>{}); + + auto inv_std_block_window = + make_tile_window(inv_std_m, make_tuple(number{}), {iM}); + + store_tile(inv_std_block_window, cast_tile(inv_std_compute_block_tensor)); + } + + // TODO: Extract normalize pipeline + const auto y_m_n = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{}); + + auto y_block_window = make_tile_window( + y_m_n, make_tuple(number{}, number{}), {iM, 0}); + + constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); + constexpr auto betaDstr = gammaDstr; + + auto gamma_block_window = + make_tile_window(gamma_n, make_tuple(number{}), {0}, gammaDstr); + + auto beta_block_window = make_tile_window( + beta_n, make_tuple(number{}, number{}), {0}, betaDstr); + + // reverse read x to reuse cache + ck_tile::index_t stride_to_right_most_window = N - kNPerBlock; + + move_tile_window(x_block_window, {0, -kNPerBlock}); + move_tile_window(gamma_block_window, {stride_to_right_most_window}); + move_tile_window(beta_block_window, {stride_to_right_most_window}); + move_tile_window(y_block_window, {0, stride_to_right_most_window}); + + // Normalization + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto 
x_block_tensor = load_tile(x_block_window); + const auto gamma_block_tensor = load_tile(gamma_block_window); + const auto beta_block_tensor = load_tile(beta_block_window); + + constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans(); + + auto y_block_tensor = + make_static_distributed_tensor(x_block_tensor.get_tile_distribution()); + + sweep_tile_span(x_spans[I1], [&](auto idx1) { + constexpr auto j_idx = make_tuple(idx1); + const auto gamma = type_convert(gamma_block_tensor[j_idx]); + const auto beta = type_convert(beta_block_tensor[j_idx]); + + sweep_tile_span(x_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + const auto mean = mean_compute_block_tensor[i_idx]; + const auto inv_std = inv_std_compute_block_tensor[i_idx]; + + const auto x = type_convert(x_block_tensor[i_j_idx]); + auto y = (x - mean) * inv_std * gamma + beta; + + y_block_tensor(i_j_idx) = type_convert(y); + }); + }); + + store_tile(y_block_window, y_block_tensor); + + move_tile_window(x_block_window, {0, -kNPerBlock}); + move_tile_window(gamma_block_window, {-kNPerBlock}); + move_tile_window(beta_block_window, {-kNPerBlock}); + move_tile_window(y_block_window, {0, -kNPerBlock}); + } + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + TwoPassLayernorm2dFwd(static_cast(kargs.p_x), + static_cast(kargs.p_gamma), + static_cast(kargs.p_beta), + static_cast(kargs.p_y), + static_cast(kargs.p_mean), + static_cast(kargs.p_invStd), + static_cast(kargs.epsilon), + kargs.M, + kargs.N); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp new file mode 100644 index 000000000..5206d36d7 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct BlockLayernorm2dFwdProblem +{ + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using BetaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using MeanDataType = remove_cvref_t; + using InvStdDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp b/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp new file mode 100644 index 000000000..1ff541d84 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +template // Sequence<... 
+struct TileLayernorm2dShape +{ + static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); + static constexpr index_t kNPerThread = ThreadTile::at(number<1>{}); + + static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); + static constexpr index_t kNPerWarp = WarpTile::at(number<1>{}); + + static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; + static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread; + + static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); + static constexpr index_t kNPerBlock = BlockTile::at(number<1>{}); + + static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp; + static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp; + + // TODO - kNNumWarps can only be 1 if we don't support cross warp welford + static_assert(kNWarpPerBlock == 1); + + static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kNWarpPerBlock; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/welford.hpp b/include/ck_tile/ops/welford.hpp new file mode 100644 index 000000000..dffaad750 --- /dev/null +++ b/include/ck_tile/ops/welford.hpp @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/welford/thread/thread_welford.hpp" +#include "ck_tile/ops/welford/warp/warp_welford.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/welford/thread/thread_welford.hpp b/include/ck_tile/ops/welford/thread/thread_welford.hpp new file mode 100644 index 000000000..2ca9a2365 --- /dev/null +++ b/include/ck_tile/ops/welford/thread/thread_welford.hpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
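+// Per-thread Welford accumulator. For each new sample x the running statistics are updated as +// count += 1; delta = x - mean; mean += delta / count; var += delta * (x - mean); +// so var holds the sum of squared deviations and is only divided by the final count later +// (see WarpMergeWelford with GetActualVariance).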
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct ThreadWelford +{ + using XDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + + template + CK_TILE_DEVICE void Update(T& mean, T& var, T x) + { + if(ck_tile::isnan(x)) + { + mean = x; + var = x; + } + else + { + T delta = x - mean; + mean += delta / cur_count_; + T delta2 = x - mean; + var += delta * delta2; + } + } + + // [CAUSION] - max_count_ is to deal with the padding problem + // max_count_ is depend on caller, eg: naive and splitN welford will have different + // calculation of max_count_ + CK_TILE_DEVICE constexpr ThreadWelford(int max_count) : cur_count_(0), max_count_(max_count) {} + + template + CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor, + MeanDistributedTensor_& mean_tensor, + VarDistributedTensor_& var_tensor) + { + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + constexpr auto spans = XDistributedTensor_::get_distributed_spans(); + + sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) { + if(cur_count_ < max_count_) + { + ++cur_count_; + + sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) { + constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1); + constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0); + + auto x = ck_tile::type_convert(x_tensor[in_dstr_idx]); + + Update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x); + }); + } + }); + } + + template + CK_TILE_DEVICE static auto MakeInitialMeanVarDistributedTensor() + { + static_assert(std::is_same_v, "wrong!"); + + constexpr auto reduce_dims = sequence<1>{}; + + constexpr auto dstr = + make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( + XDistributedTensor_::get_tile_distribution() + .get_static_tile_distribution_encoding(), + reduce_dims)); + + auto tensor = make_static_distributed_tensor(dstr); + clear_tile(tensor); + + return tensor; + } + + template + CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor) + { + auto mean_tensor = MakeInitialMeanVarDistributedTensor(); + auto var_tensor = MakeInitialMeanVarDistributedTensor(); + + (*this)(x_tensor, mean_tensor, var_tensor); + + return ck_tile::make_tuple(mean_tensor, var_tensor); + } + + int cur_count_; + int max_count_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/welford/warp/warp_welford.hpp b/include/ck_tile/ops/welford/warp/warp_welford.hpp new file mode 100644 index 000000000..687b61f43 --- /dev/null +++ b/include/ck_tile/ops/welford/warp/warp_welford.hpp @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct WarpMergeWelford +{ + using ComputeDataType = remove_cvref_t; + + template + CK_TILE_DEVICE static void + Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) + { + int count = count_a + count_b; + T count_ = type_convert(count); + T count_a_ = type_convert(count_a); + T count_b_ = type_convert(count_b); + T count_b_over_count = count == 0 ? 
type_convert(0) : count_b_ / count_; + + T delta = mean_b - mean_a; + mean_a += delta * count_b_over_count; + var_a += var_b + delta * delta * count_a_ * count_b_over_count; + count_a = count; + } + + template + CK_TILE_DEVICE void + operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count) + { + using Dstr = typename MeanDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + static_assert(std::is_same_v, + "wrong!"); + + constexpr index_t NDimP = Dstr::get_num_of_dimension_p(); + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_lane = NDimP - 1; + + const auto ps_idx = make_array(get_warp_id(), get_lane_id()); + const auto rs_idx = + mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); + + constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size(); + static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()); + + const int original_count = count; + + // loop over thread data + static_for<0, thread_buf_size, 1>{}([&](auto i) { + auto v_local_mean = mean_tensor.get_thread_buffer()[i]; + auto v_local_var = var_tensor.get_thread_buffer()[i]; + auto v_local_count = original_count; + + // cross-lane reduce for replication + // only reduce on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // reduction sweep forward + static_for<0, nstage, 1>{}([&](auto istage) { + constexpr index_t lid_delta = + lid_over_rid_derivative * (1 << (nstage - istage - 1)); + + // pull data from remote lane + const auto v_remote_mean = warp_shuffle_down(v_local_mean, lid_delta); + const auto v_remote_var = warp_shuffle_down(v_local_var, lid_delta); + const auto v_remote_count = warp_shuffle_down(v_local_count, lid_delta); + + // welford merge + Merge(v_local_mean, + v_local_var, + v_local_count, + v_remote_mean, + v_remote_var, + v_remote_count); + }); + } + }); + + // cross-lane broadcast for replication + // only broadcast on R dimension correspond to lane + // (lane id maps to this R dimension) + if constexpr(BroadcastLane) + { + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + const index_t r_id = rs_idx[idim_r]; + + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // broadcast sweep backward + static_for<0, nstage, 1>{}([&](auto istage) { + // do I hold reduced data? 
+ const bool do_i_hold_reduced_data = r_id < (1 << istage); + + constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage); + + // pull data from remote lane + const auto v_remote_mean = warp_shuffle_up(v_local_mean, lid_delta); + const auto v_remote_var = warp_shuffle_up(v_local_var, lid_delta); + const auto v_remote_count = warp_shuffle_up(v_local_count, lid_delta); + + // decide whether to update local data with remote data + v_local_mean = do_i_hold_reduced_data ? v_local_mean : v_remote_mean; + v_local_var = do_i_hold_reduced_data ? v_local_var : v_remote_var; + v_local_count = do_i_hold_reduced_data ? v_local_count : v_remote_count; + }); + } + }); + } + + mean_tensor.get_thread_buffer()(i) = v_local_mean; + + if constexpr(GetActualVariance) + var_tensor.get_thread_buffer()(i) = v_local_var / v_local_count; + else + var_tensor.get_thread_buffer()(i) = v_local_var; + + count = v_local_count; + }); + } +}; + +} // namespace ck_tile -- GitLab From 3e9711f0cb1c7ffd3826a93dfa6dd65e98715636 Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Tue, 25 Jun 2024 14:37:35 -0700 Subject: [PATCH 65/96] CK Instance Gen (#1145) * Format * Format * Format * Remove const * Use the right template * Format * Format * add row/col instances * Add missing file * fixed * fixing block to etile error * Format * Updates * Format * fixed rrr layout * generating a sample JSON file: currently contains includes, prologue/epilogue and instances * version where the json is passed into the instances to generate a key * updated run function to just launch kernel * updated run function: only contains kernel object, json file is updated but still needs to be cleaned up, added front-end API to parse JSON into character buffer * adding in testing files * cleaned up comments, still need to work on including header files * removed unneeded files * removed/commented out JSON implementation * added fusion(prologue/epilogue) into instance generation * working on instance selection * added instance selection, need to fix instance validation * removed block2etile map validity check for testing purposes * test running: failing due to incorrect files/input * all grid descs/ptrs completed, but device file not found * Update test and embed modules * Restore older version * added convolution operation, written test, debugging generated code for compilation * attempting to include CK in host directory: _Float16 error * CK header file issues * slight fix * don't crash when hip can't report total memory * dump generated code to a file * changing sizes * creating tensor descriptors using CK methods: set up grid desc manually, also trying to set up an argument pointer - this needs to be fixed * some fixes to call the device code * separating test files for conv and gemm * completed arg ptr, now have linking errors * clang format fix * resolved linker issues in conv test * remove dependency on libutility from ck * resolved num dim error * properly passing arg ptr, errors with passing typenames: redefinition/redeclaration * undo the commenting of device function * hand created kernel code to find rtc issues * dump the full src to file * resolved redeclaration errors, cleaned up errors for Amber's kernel code * debugging purposes: redeclaration error * config files * resolved errors for NumTensor and redeclaration, formatted version.h * resolved most errors in manually added kernel and my own. 
error with calling kernel object: overloaded function type * WIP: close to getting kernel compiled * WIP: fixing rtc errors * fixed sequence errors, formatting, still one error with run fcn * yay: kernel compiles and runs * updated templated/generated version to run and compile * minor fixes * working generated example, resolved memory access error due to padding * adding in reference kernel, validation failing against reference * debugging: printing kernel argsz * reduced error in results * debugged reference kernel and output errors, added to generated version, currently debugging prologue function issues * working validation (using reference convolution) with prologue function for both hard-coded and generated version * WIP: create an alt version that creates Argument on the device * wip: added new duplicate files, fixed fusion templating errors from working example, setting up kernel arguments * wip: making necessary methods device code * added grid descs, working on grid pointers, errors with stl numerics * wip: updating kernel args - issue, replacing some std functions * replaced std::accumulate call with temp hardcoded version * wip: args causing memory issue * Construct Argument object inside the kernel and use it to call convolution device function. Code runs and verification passes * adding object file dump * temporary hardcoding of grid size, can remove device op inst + arg ptr * minor fix for grid size * added modified example where arg ptr is created on the device for generated version as well * removed device op instance and arg ptr from modified examples * moving device op file for testing purposes and to properly build CK * commenting out print-outs * adjust compiler args to produce a valid ELF file * temporary removal of validation * reverting compiler args back for working example * retrieve necessary arguments from generated template parameters in correct format * calculating grid size on host-side, still need to clean up process, pass parameters to host functions properly * scaled up factory functions/wrapper structs to implement host-side launch parameter calculations using CK host side functions - in hard-coded example * temporary change to generate ELF format binary object file * removed unecessary code, added comments * formatting fix * cleaned up code, added new tests, restructured library: move helper into CK * refactored launch parameter calculation to be more concise * renamed files and variables for more clarity/uniformity * more code cleaning, removed debug statements * moved majority of my files into codegen directory, running properly * updated Embed.cmake(string_view) in codegen directory * updated host directory to match Embed.cmake as well * added old tests in * updated instance generation methods to be more concise * removed layout from launch parameter calculation * working test * fixed issue with verification, all instances working * updated verification in other tests * removed duplicate matrix padder file, removed code dumps * removed old hard-coded tests * removed old host directory, all files in codegen directory now * fixed copyright in files * commenting out validation * renamed files * made changes for review: fixed copyright, renamed files for clarity, removed comments, refactored code * updated headers * removing duplicate file for fwd conv to gemm, merging with original file * fix building codegen with clang++ directly * resolving build error from conv_fwd_to_gemm * fix for previous error * renaming tests * created common test file * cleaned 
up code, added comments * renamed device op * fixed typos in comments * removed extra space * code cleanup: resolving Amber's comments * removed wrapper struct for matrix padder, fixed template * cleaned up if statements for better readability --------- Co-authored-by: Paul Co-authored-by: Jing Zhang Co-authored-by: M. Amber Hassaan Co-authored-by: illsilin Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- codegen/CMakeLists.txt | 28 +- codegen/driver/main.cpp | 42 +- .../ck/host/device_gemm_multiple_d.hpp | 2 +- .../host/device_gemm_multiple_d/operation.hpp | 17 +- .../host/device_gemm_multiple_d/problem.hpp | 17 +- .../conv_fwd_op.hpp | 60 ++ .../conv_fwd_problem.hpp | 56 ++ codegen/include/ck/host/headers.hpp | 1 - codegen/include/ck/host/operation/gemm.hpp | 2 +- codegen/include/ck/host/stringutils.hpp | 2 +- codegen/include/ck/host/types.hpp | 18 +- codegen/include/ck/host/utils.hpp | 5 +- codegen/src/device_gemm_multiple_d.cpp | 15 +- ...gemm_multiple_d_operation_xdl_cshuffle.cpp | 57 +- .../device_grouped_conv_fwd_multiple_abd.cpp | 42 + ...wd_multiple_abd_operation_xdl_cshuffle.cpp | 364 ++++++++ codegen/src/headers.cpp | 2 +- codegen/src/types.cpp | 8 + codegen/src/utils.cpp | 2 +- codegen/test/CMakeLists.txt | 14 +- codegen/test/common.hpp | 134 +++ codegen/test/gemm_multiple_d.cpp | 7 +- .../test/grouped_conv_fwd_multiple_d_v1.cpp | 209 +++++ .../test/grouped_conv_fwd_multiple_d_v2.cpp | 209 +++++ .../test/grouped_conv_fwd_multiple_d_v3.cpp | 209 +++++ .../test/grouped_conv_fwd_multiple_d_v4.cpp | 209 +++++ codegen/test/rtc/src/compile_kernel.cpp | 8 + codegen/test/rtc/src/hip.cpp | 6 +- .../ck/tensor_operation/gpu/device/helper.hpp | 359 ++++++++ ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 781 ++++++++++++++++++ .../gpu/device/matrix_padder.hpp | 13 + .../transform_conv_fwd_to_gemm.hpp | 564 +++++++++++++ include/ck/utility/array.hpp | 2 + 33 files changed, 3417 insertions(+), 47 deletions(-) create mode 100644 codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp create mode 100644 codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp create mode 100644 codegen/src/device_grouped_conv_fwd_multiple_abd.cpp create mode 100644 codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp create mode 100644 codegen/test/common.hpp create mode 100644 codegen/test/grouped_conv_fwd_multiple_d_v1.cpp create mode 100644 codegen/test/grouped_conv_fwd_multiple_d_v2.cpp create mode 100644 codegen/test/grouped_conv_fwd_multiple_d_v3.cpp create mode 100644 codegen/test/grouped_conv_fwd_multiple_d_v4.cpp create mode 100644 include/ck/tensor_operation/gpu/device/helper.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 72549c9a4..d8b22fc94 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.16) -project(composable_kernel_host) +project(composable_kernel_host LANGUAGES CXX HIP) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -12,24 +12,38 @@ find_package(ROCM) include(ROCMInstallTargets) include(ROCMTest) +add_compile_options(-std=c++17) +find_package(hip) +## HIP +set(CMAKE_HIP_PLATFORM amd) +set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER}) +set(CMAKE_HIP_EXTENSIONS ON) +message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") + +# add include directories +include_directories(BEFORE + 
${PROJECT_BINARY_DIR}/include + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/library/include + ${HIP_INCLUDE_DIRS} + ) + list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) include(Embed) file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS - ${CK_ROOT}/include/ck/*.hpp) + ${CK_ROOT}/include/ck/*.hpp) message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "RELATIVE: ${CK_ROOT}/include") add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) -add_definitions(-std=c++17) - file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) target_link_libraries(ck_host PRIVATE ck_headers) -set_target_properties(ck_host PROPERTIES - LINKER_LANGUAGE CXX - POSITION_INDEPENDENT_CODE ON) +set_target_properties(ck_host PROPERTIES + LINKER_LANGUAGE CXX + POSITION_INDEPENDENT_CODE ON) target_include_directories(ck_host PUBLIC $ diff --git a/codegen/driver/main.cpp b/codegen/driver/main.cpp index dfd513106..c7d295de9 100644 --- a/codegen/driver/main.cpp +++ b/codegen/driver/main.cpp @@ -5,24 +5,27 @@ #include #include #include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/stringutils.hpp" using ck::host::Transform; struct Emitters { + // retrieve the hard-coded instances provided, template them, and then store them in a map std::unordered_map()>> m; template - void Register(const std::string& name) + void Register(const std::string& name, const std::string& prologue, const std::string& epilogue) { - m[name] = [] { - auto configs = T::CreateOperations(); + m[name] = [&] { + auto configs = T::CreateOperations(prologue, epilogue); return Transform(configs, [](const auto& ops) { return ToTuple(ops); }); }; } + // takes in an operation instance and uses it to substitute the correct values into the template template static std::string ToTuple(const T& ops) { @@ -31,6 +34,7 @@ struct Emitters return "std::tuple<\n" + ck::host::JoinStrings(templates, ",\n") + ">"; } + // Join together all the strings in the map std::string Emit(const std::string& name) { return ck::host::JoinStrings(m.at(name)(), "\n"); } std::vector List() const @@ -43,9 +47,38 @@ int main(int argc, const char* argv[]) { std::string prog = argv[0]; std::vector args(argv + 1, argv + argc); + + // Specify problem type and problem size + ck::host::device_gemm_multiple_d::Problem prob; + prob.M = 1024; + prob.N = 1024; + prob.K = 1024; + + // user provided fusion + std::string prologue = ""; + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};)"; + + // Load in operations into the Register Emitters e; e.Register( - "DeviceGemmMultipleD_Xdl_CShuffle"); + "DeviceGemmMultipleD_Xdl_CShuffle", prologue, epilogue); if(args.empty() or std::any_of(args.begin(), args.end(), [](auto arg) { return arg == "-h" or arg == "--help"; @@ -64,6 +97,7 @@ int main(int argc, const char* argv[]) return 0; } + // print out all the instances for the operation that was chosen at the command line for(auto name : args) std::cout << e.Emit(name) << std::endl; diff --git a/codegen/include/ck/host/device_gemm_multiple_d.hpp 
b/codegen/include/ck/host/device_gemm_multiple_d.hpp index 88e040db5..02c19c88e 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp index f9d39633a..359da7d8c 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp @@ -14,10 +14,15 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// defines all values need for an instance of fwd conv struct Operation_Xdl_CShuffle { - static std::vector> CreateOperations(); - static std::vector CreateOperations(const Problem& prob); + // returns a vector of instances, only given fusion operators: will use default problem spec + static std::vector> + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, given a problem spec and fusion operators + static std::vector + CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue); TensorDesc A{}; TensorDesc B{}; DataType acc = DataType::Float; @@ -27,13 +32,21 @@ struct Operation_Xdl_CShuffle std::string a_elem_op = PassThrough; std::string b_elem_op = PassThrough; std::string cde_elem_op = Bilinear; + std::string prologue = ""; + std::string epilogue = ""; std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; + // tuning parameters operation::TileDesc tile_desc{}; operation::BlockTransferDesc a_block_transfer{}; operation::BlockTransferDesc b_block_transfer{}; operation::CShuffleDesc cshuffle{}; operation::CBlockTransferDesc c_block_transfer{}; + // functions to update fusion operators if provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + /**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_); + // returns a templated instance Solution ToSolution() const; }; diff --git a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp index f6dbc2b6e..f4036328e 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
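// A minimal sketch of how the two-step flow declared above might be driven
// (CreateOperations with fusion strings, then ToSolution per instance); the
// empty prologue/epilogue strings and the printing loop are illustrative only,
// mirroring the codegen driver earlier in this patch:
//
//     auto op_lists = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::
//         CreateOperations(/*prologue=*/"", /*epilogue=*/"");
//     for(const auto& ops : op_lists)   // one list per layout combination
//         for(const auto& op : ops)     // one entry per tuning configuration
//             std::cout << op.ToSolution().ToTemplateString() << "\n";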
#pragma once @@ -12,11 +12,14 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// defines the problem specification for a GEMM operation struct Problem { - std::size_t M = 0; - std::size_t N = 0; - std::size_t K = 0; + // dimensions for GEMM operation + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + // layouts for tensors bool TransA = false; bool TransB = false; bool TransE = false; @@ -29,9 +32,13 @@ struct Problem std::string BElementOp = PassThrough; std::string CDEElementOp = PassThrough; + // returns the correct device op file for the operation std::string GetIncludeHeader() const; - std::vector GetSolutions(const std::string& arch) const; + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; }; } // namespace device_gemm_multiple_d diff --git a/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp new file mode 100644 index 000000000..5ad1dce17 --- /dev/null +++ b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" + +namespace ck { +namespace host { +namespace conv { + +// defines the values needed for an instance of forward convolution and functions to return +// (templated) instances +struct Operation_Conv_Fwd_Xdl_Cshuffle +{ + // returns a vector of instances given the fusion operations, uses default values for problem + // spec + static std::vector + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, provided with a problem spec and fusion operations + static std::vector CreateOperations( + const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue); + std::size_t NumDim; + TensorDesc A{}; + TensorDesc B{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::vector Ds = {}; + TensorDesc E{}; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string cde_elem_op = PassThrough; + std::string prologue = ""; + std::string epilogue = ""; + std::string conv_specialization = + "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default"; + std::string gemm_specialization = + "ck::tensor_operation::device::GemmSpecialization::MNKPadding"; + // tuning parameters + operation::TileDesc tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + // functions to update fusion operations if they are provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + // returns a templated instance + Solution ToSolution() const; +}; + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp new file mode 100644 index 000000000..433f9a8fc 
--- /dev/null +++ b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace conv { + +// defines the problem specification for a forward convolution operation +struct Problem_Conv_Fwd +{ + std::size_t NumDim = 0; + // size of a forward convolution operation + std::size_t G = 0; + std::size_t N = 0; + std::size_t C = 0; + std::size_t Hi = 0; + std::size_t Wi = 0; + std::size_t Ho = 0; + std::size_t Wo = 0; + std::size_t K = 0; + std::size_t Y = 0; + std::size_t X = 0; + Layout ALayout = Layout::NHWGC; + Layout BLayout = Layout::GKYXC; + Layout ELayout = Layout::NHWGK; + std::vector DsLayout = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; + std::vector DsDataType = {}; + std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string CDEElementOp = "ck::tensor_operation::element_wise::PassThrough"; + + // returns the correct device op file for the operation + std::string GetIncludeHeader() const; + + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; +}; + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/headers.hpp b/codegen/include/ck/host/headers.hpp index 3da05baaa..54f8d9f73 100644 --- a/codegen/include/ck/host/headers.hpp +++ b/codegen/include/ck/host/headers.hpp @@ -4,7 +4,6 @@ #pragma once #include -#include #include #include #include diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp index f587122b0..84ef92f0a 100644 --- a/codegen/include/ck/host/operation/gemm.hpp +++ b/codegen/include/ck/host/operation/gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp index 01374b86c..89c1884d2 100644 --- a/codegen/include/ck/host/stringutils.hpp +++ b/codegen/include/ck/host/stringutils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp index 23488a66d..812c07367 100644 --- a/codegen/include/ck/host/types.hpp +++ b/codegen/include/ck/host/types.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
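// A minimal sketch of how Problem_Conv_Fwd (declared above) might be filled in
// and queried for instances; the sizes and the "gfx90a" target string are
// illustrative, and the real usage appears in the codegen tests added later in
// this patch:
//
//     ck::host::conv::Problem_Conv_Fwd prob;
//     prob.NumDim = 2;
//     prob.G  = 32;  prob.N  = 256; prob.C  = 32; prob.K  = 64;
//     prob.Y  = 3;   prob.X  = 3;
//     prob.Hi = 28;  prob.Wi = 28;  prob.Ho = 28; prob.Wo = 28;
//     for(const auto& solution : prob.GetSolutions("gfx90a", /*prologue=*/"", /*epilogue=*/""))
//         std::cout << solution.ToTemplateString() << "\n";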
#pragma once @@ -12,6 +12,7 @@ namespace ck { namespace host { +// holds the templated instance, substitues values into template from instancess struct Solution { @@ -33,6 +34,7 @@ struct Solution std::unordered_map template_values; }; +// supported data types enum class DataType { Half, @@ -40,22 +42,28 @@ enum class DataType Int8, Int32 }; - std::string ToString(DataType dt); +// supported layouts: gemm and fwd conv enum class Layout { Row, - Column + Column, + GKYXC, + GKCYX, + GNHWK, + GNHWC, + NHWGC, + NHWGK }; - std::string ToString(Layout dl); +Layout ToLayout(bool Trans); // returns the layout for gemm +// supported GEMM types enum class GemmType { Default }; - std::string ToString(GemmType gt); struct TensorDesc diff --git a/codegen/include/ck/host/utils.hpp b/codegen/include/ck/host/utils.hpp index e8785a456..21926814f 100644 --- a/codegen/include/ck/host/utils.hpp +++ b/codegen/include/ck/host/utils.hpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include +#include +#include namespace ck { namespace host { @@ -12,6 +14,5 @@ namespace host { std::size_t integer_divide_ceil(std::size_t x, std::size_t y); const std::unordered_set& get_xdlop_archs(); - } // namespace host } // namespace ck diff --git a/codegen/src/device_gemm_multiple_d.cpp b/codegen/src/device_gemm_multiple_d.cpp index ec25afc0f..44bc051a8 100644 --- a/codegen/src/device_gemm_multiple_d.cpp +++ b/codegen/src/device_gemm_multiple_d.cpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
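// A small illustration of the layout helpers extended above; the string values
// follow the ToString(Layout) switch added in codegen/src/types.cpp later in
// this patch:
//
//     ck::host::Layout a = ck::host::ToLayout(false);  // Layout::Row    (not transposed)
//     ck::host::Layout b = ck::host::ToLayout(true);   // Layout::Column (transposed)
//     // ToString(ck::host::Layout::NHWGC) == "ck::tensor_layout::convolution::NHWGC"
//     // ToString(ck::host::Layout::GKYXC) == "ck::tensor_layout::convolution::GKYXC"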
#include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_gemm_multiple_d/operation.hpp" @@ -11,23 +11,28 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// return the relevant device op file based on the operation std::string Problem::GetIncludeHeader() const { return "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"; } -std::vector Problem::GetSolutions(const std::string& arch) const +// returns templated instances when provided with a problem specification +std::vector Problem::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const { if(get_xdlop_archs().count(arch) == 0) return {}; - auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations(*this); + auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations( + *this, prologue, epilogue); // obtains vector of instances std::vector result; std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { - return op.ToSolution(); + return op.ToSolution(); // template instance with correct values }); return result; } } // namespace device_gemm_multiple_d } // namespace host -} // namespace ck \ No newline at end of file +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp index 9e397497e..a2e8eccbf 100644 --- a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -10,6 +10,7 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// calculate appropriate Gemm Specification based on input tensor dimensions static std::string GetGemmSpec(const std::size_t m, const std::size_t n, const std::size_t k, @@ -30,9 +31,40 @@ static std::string GetGemmSpec(const std::size_t m, return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; } +// function to update prologue/epilogue with user provided operation +void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue) +{ + if(!prologue.empty()) + { + this->prologue = prologue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->prologue = ""; + } +} + +void Operation_Xdl_CShuffle::update_epilogue(const std::string& epilogue) +{ + if(!epilogue.empty()) + { + this->epilogue = epilogue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->epilogue = ""; + } +} + +// accounts for all possible combinations of Row/Col major static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } -std::vector Operation_Xdl_CShuffle::CreateOperations(const Problem& prob) +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Xdl_CShuffle::CreateOperations( + const Problem& prob, const std::string& prologue, const std::string& epilogue) { std::vector result; @@ -155,6 +187,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(con // clang-format on }; + // choose correct arrangement of tuning parameters based on the layout of each tensor const auto a_block_descriptions = prob.TransA ? 
a_block_descriptions_colmajor : a_block_descriptions_rowmajor; const auto b_block_descriptions = @@ -165,6 +198,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(con assert(tile_descriptions.size() == cshuffle_descriptions.size()); assert(tile_descriptions.size() == c_block_descriptions.size()); + // Put all values together into a single operation > store into the result vector for(std::size_t i = 0; i < tile_descriptions.size(); i++) { Operation_Xdl_CShuffle x; @@ -188,12 +222,17 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(con x.tile_desc.m_per_block, x.tile_desc.n_per_block, x.tile_desc.k_per_block); + x.update_prologue(prologue); + x.update_epilogue(epilogue); result.push_back(x); } return result; } -std::vector> Operation_Xdl_CShuffle::CreateOperations() +// set up instances when not provided with a problem specification, use default operation values and +// all possible layout combinations +std::vector> +Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue) { std::vector problems; for(bool TransA : {true, false}) @@ -204,7 +243,8 @@ std::vector> Operation_Xdl_CShuffle::CreateO prob.TransB = TransB; problems.push_back(prob); } - return Transform(problems, [](const Problem& p) { return CreateOperations(p); }); + return Transform(problems, + [&](const Problem& p) { return CreateOperations(p, prologue, epilogue); }); } static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = @@ -224,9 +264,20 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " "${CDEBlockTransferScalarPerVector_NPerBlock}>"; +// use hardcoded instances from vector of operations to substitute values into instance template Solution Operation_Xdl_CShuffle::ToSolution() const { std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.m_per_block) + "_" + + std::to_string(this->tile_desc.n_per_block) + "_" + + std::to_string(this->tile_desc.k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.n_Xdl_per_wave)}, {"LayoutA", ToString(this->A.layout)}, {"LayoutB", ToString(this->B.layout)}, {"LayoutDs", diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp new file mode 100644 index 000000000..c689e5ec9 --- /dev/null +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp @@ -0,0 +1,42 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
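// How the "name" value emitted by ToSolution above is consumed: the conv code
// template later in this patch declares an extern "C" kernel named run_${name},
// and the RTC tests compile it by that symbol. A sketch, where the templated
// GetTemplateParameter signature is an assumption based on its use in those
// tests:
//
//     auto name = solution.GetTemplateParameter<std::string>("name");
//     // e.g. "256_128_256_32_8_8_32_32_4_2" for the second tuning configuration
//     rtc::compile_options options;
//     options.kernel_name = "run_" + name;
//     auto kernel = rtc::compile_kernel(srcs, options);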
+ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/utils.hpp" +#include +#include + +namespace ck { +namespace host { +namespace conv { + +// return the relevant device op file based on the operation +// NOTE: this is a modified version of the original CK file that calls the kernel from a device +// function and makes the Argument class accessible on the device +std::string Problem_Conv_Fwd::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/" + "codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"; +} + +// return vector of forward convolution instances when provided with a problem instance +std::vector Problem_Conv_Fwd::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::conv::Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations( + *this, prologue, epilogue); + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); + }); + return result; +} + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp new file mode 100644 index 000000000..94161a76d --- /dev/null +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp @@ -0,0 +1,364 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace conv { + +// calculate appropriate Gemm Specification based on input tensor dimensions +// NOTE: in CK, MNKPadding is always used for forward convolution +static std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +// function to update prologue/epilogue with user provided operation +void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologue) +{ + if(!prologue.empty()) + { + this->prologue = prologue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->prologue = ""; + } +} + +void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epilogue) +{ + if(!epilogue.empty()) + { + this->epilogue = epilogue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->epilogue = ""; + } +} + +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations( + const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue) +{ + std::vector result; + + std::vector 
tile_descriptions = { + // clang-format off +// Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| NumGemmK| +// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch| +// | | | | | | | | Wave| Wave| Stage| +// | | | | | | | | | | | + { 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1}, + { 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1}, + { 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1}, + { 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1}, + { 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1}, + { 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1} + // clang-format on + }; + + std::vector a_block_descriptions = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1} + // clang-format on + }; + + std::vector b_block_descriptions = { + // clang-format off +// BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1} + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1} + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// | + { S<1, 16, 1, 4>, 1}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 4>, 1}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 8>, 8} + // clang-format on + }; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Conv_Fwd_Xdl_Cshuffle x; + x.NumDim = prob.NumDim; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, prob.ALayout}; + x.B = TensorDesc{prob.BDataType, prob.BLayout}; + x.E = TensorDesc{prob.EDataType, prob.ELayout}; + x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { + return TensorDesc{dt, lo}; + }); + x.a_elem_op = 
prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + x.update_prologue(prologue); + x.update_epilogue(epilogue); + result.push_back(x); + } + return result; +} + +// set up instances when not provided with a problem specification, use default operation values +std::vector +Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations(const std::string& prologue, + const std::string& epilogue) +{ + Problem_Conv_Fwd prob; + return CreateOperations(prob, prologue, epilogue); +} + +static const char* const CopyDevice_ConvTemplate = + R"( +${Prologue} +${Epilogue} + +using CDEElementOp = Epilogue; +using DeviceConv = ck::tensor_operation::device::CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<${NumDim}, ${LayoutA}, ${LayoutB}, ${LayoutDs}, ${LayoutE}, ${ADataType}, ${BDataType}, ${AccDataType}, ${CShuffleDataType}, ${DsDataType}, ${EDataType}, ${AElementwiseOperation}, ${BElementwiseOperation}, ${CDEElementwiseOperation}, ${ConvSpecialization}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, ${MPerBlock}, ${NPerBlock}, ${KPerBlock}, ${AK1}, ${BK1}, ${MPerXDL}, ${NPerXDL}, ${MXdlPerWave}, ${NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, ${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, ${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, ${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, ${BBlockTransferThreadClusterLengths_BK0_N_BK1}, ${BBlockTransferThreadClusterArrangeOrder}, ${BBlockTransferSrcAccessOrder}, ${BBlockTransferSrcVectorDim}, ${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, ${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, ${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, ${CDEBlockTransferScalarPerVector_NPerBlock}>; + +constexpr ck::index_t NumATensor = ck::tensor_operation::device::GetNumABTensors(); +constexpr ck::index_t NumBTensor = ck::tensor_operation::device::GetNumABTensors(); + +extern "C" __global__ void run_${name}( + const ${ADataType}* in_dev, + const ${BDataType}* wei_dev, + ${EDataType}* __restrict__ out_dev, + ck::Array in_lengths, + ck::Array in_strides, + ck::Array wei_lengths, + ck::Array wei_strides, + ck::Array out_lengths, + ck::Array out_strides, + ck::Array conv_filter_strides, + ck::Array conv_filter_dilations, + ck::Array input_left_pads, + ck::Array input_right_pads, + const ${AElementwiseOperation} a_element_op, + const ${BElementwiseOperation} b_element_op, + const ${CDEElementwiseOperation} cde_element_op +){ + + + auto arg = DeviceConv::Argument(in_dev, + wei_dev, + ck::Array{}, + out_dev, + in_lengths, + in_strides, + wei_lengths, + wei_strides, + ck::Array, 0>{}, + ck::Array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + ${AElementwiseOperation}{}, + ${BElementwiseOperation}{}, + ${CDEElementwiseOperation}{1.0f, 1.0f}); + + constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler(); + + // GridwiseGemm + using GridwiseGemm = DeviceConv::GridwiseGemm; + + static constexpr auto I0 = ck::Number<0>{}; + + ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + const ${ADataType}*, + const ${BDataType}*, + typename GridwiseGemm::DsGridPointer, + ${EDataType}, + ${AElementwiseOperation}, + ${BElementwiseOperation}, + ${CDEElementwiseOperation}, + DeviceConv::AGridDesc_AK0_M_AK1, + DeviceConv::BGridDesc_BK0_N_BK1, + 
DeviceConv::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceConv::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceConv::Block2ETileMap, + ck::tensor_operation::device::ComputePtrOffsetOfStridedBatch, + ck::integral_constant{}, + false, + false> + ( + arg.p_as_grid_.At(I0), + arg.p_bs_grid_.At(I0), + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_ + ); + +} +)"; + +// use hardcoded instances from vector of operations to substitute values into instance template +Solution Operation_Conv_Fwd_Xdl_Cshuffle::ToSolution() const +{ + std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.m_per_block) + "_" + + std::to_string(this->tile_desc.n_per_block) + "_" + + std::to_string(this->tile_desc.k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"NumDim", std::to_string(this->NumDim)}, + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB", ToString(this->B.layout)}, + {"LayoutDs", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.layout); }))}, + {"LayoutE", ToString(this->E.layout)}, + {"ADataType", ToString(this->A.element)}, + {"BDataType", ToString(this->B.element)}, + {"AccDataType", ToString(this->acc)}, + {"ComputeDataType", ToString(this->A.element)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"DsDataType", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.element); }))}, + {"EDataType", ToString(this->E.element)}, + {"AElementwiseOperation", this->a_elem_op}, + {"BElementwiseOperation", this->b_elem_op}, + {"CDEElementwiseOperation", this->cde_elem_op}, + {"Prologue", this->prologue}, + {"Epilogue", this->epilogue}, + {"ConvSpecialization", this->conv_specialization}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", std::to_string(this->tile_desc.block_size)}, + {"MPerBlock", std::to_string(this->tile_desc.m_per_block)}, + {"NPerBlock", std::to_string(this->tile_desc.n_per_block)}, + {"KPerBlock", std::to_string(this->tile_desc.k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, + {"MXdlPerWave", std::to_string(this->tile_desc.m_Xdl_per_wave)}, + {"NXdlPerWave", std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + 
{"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"BBlockTransferThreadClusterLengths_BK0_N_BK1", + this->b_block_transfer.thread_cluster_length}, + {"BBlockTransferThreadClusterArrangeOrder", + this->b_block_transfer.thread_cluster_arrange_order}, + {"BBlockTransferSrcAccessOrder", this->b_block_transfer.src_access_order}, + {"BBlockTransferSrcVectorDim", std::to_string(this->b_block_transfer.src_vec_dim)}, + {"BBlockTransferSrcScalarPerVector", + std::to_string(this->b_block_transfer.src_scalar_per_vector)}, + {"BBlockTransferDstScalarPerVector_BK1", + std::to_string(this->b_block_transfer.dst_scalar_per_vector_k1)}, + {"BBlockLdsExtraN", std::to_string(this->b_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CDEBlockTransferScalarPerVector_NPerBlock", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + }; + + return Solution{InterpolateString(CopyDevice_ConvTemplate, values), std::move(values)}; +} + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp index 6fcb94cdb..f685aca04 100644 --- a/codegen/src/headers.cpp +++ b/codegen/src/headers.cpp @@ -14,4 +14,4 @@ std::unordered_map GetHeaders() } } // namespace host -} // namespace ck \ No newline at end of file +} // namespace ck diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp index d43df73f3..a8a8b10c0 100644 --- a/codegen/src/types.cpp +++ b/codegen/src/types.cpp @@ -29,12 +29,20 @@ std::string ToString(DataType dt) throw std::runtime_error("Incorrect data type"); } +Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } + std::string ToString(Layout dl) { switch(dl) { case Layout::Row: return "ck::tensor_layout::gemm::RowMajor"; case Layout::Column: return "ck::tensor_layout::gemm::ColumnMajor"; + case Layout::GKCYX: return "ck::tensor_layout::convolution::GKCYX"; + case Layout::GKYXC: return "ck::tensor_layout::convolution::GKYXC"; + case Layout::GNHWK: return "ck::tensor_layout::convolution::GNHWK"; + case Layout::GNHWC: return "ck::tensor_layout::convolution::GNHWC"; + case Layout::NHWGC: return "ck::tensor_layout::convolution::NHWGC"; + case Layout::NHWGK: return "ck::tensor_layout::convolution::NHWGK"; } throw std::runtime_error("Incorrect layout"); } diff --git a/codegen/src/utils.cpp b/codegen/src/utils.cpp index cd6700c48..19627d4cf 100644 --- a/codegen/src/utils.cpp +++ b/codegen/src/utils.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/host/utils.hpp" diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt index 897cce1c9..f89128601 100644 --- a/codegen/test/CMakeLists.txt +++ b/codegen/test/CMakeLists.txt @@ -1,11 +1,13 @@ - list(APPEND CMAKE_PREFIX_PATH /opt/rocm) add_subdirectory(rtc) - file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) foreach(TEST_SRC ${TEST_SRCS}) -get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) -rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC}) -target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host) -target_include_directories(test_host_${BASE_NAME} PUBLIC include()) + set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP) + get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) + rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC}) + target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host) + # target_link_libraries(test_host_${BASE_NAME} ${CK_ROOT}/build/lib/libutility.a) + target_include_directories(test_host_${BASE_NAME} PUBLIC include()) + target_include_directories(test_host_${BASE_NAME} PUBLIC ${CK_ROOT}/include) + target_include_directories(test_host_${BASE_NAME} PUBLIC ${CK_ROOT}/library/include) endforeach() diff --git a/codegen/test/common.hpp b/codegen/test/common.hpp new file mode 100644 index 000000000..99d4c6497 --- /dev/null +++ b/codegen/test/common.hpp @@ -0,0 +1,134 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +std::vector get_headers_for_test() +{ + std::vector result; + auto hs = ck::host::GetHeaders(); + std::transform( + hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { + return {p.first, p.second}; + }); + return result; +} + +template +std::size_t GetSize(V mLens, V mStrides) +{ + std::size_t space = 1; + for(std::size_t i = 0; i < mLens.Size(); ++i) + { + if(mLens[i] == 0) + continue; + + space += (mLens[i] - 1) * mStrides[i]; + } + return space; +} + +template +rtc::buffer generate_buffer(V mLens, V mStrides, std::size_t seed = 0) +{ + std::size_t space = GetSize(mLens, mStrides); + rtc::buffer result(space); + std::mt19937 gen(seed); + std::uniform_real_distribution dis(-1.0); + std::generate(result.begin(), result.end(), [&] { return dis(gen); }); + // std::fill(result.begin(), result.end(), 1); + return result; +} + +template +bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) +{ + return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) { + return fabs(x - y) < atol + rtol * fabs(y); + }); +} + +std::string classify(double x) +{ + switch(std::fpclassify(x)) + { + case FP_INFINITE: return "inf"; + case FP_NAN: return "nan"; + case FP_NORMAL: return "normal"; + case FP_SUBNORMAL: return "subnormal"; + case FP_ZERO: return "zero"; + default: return "unknown"; + } +} + +template +void print_classification(const Buffer& x) +{ + std::unordered_set result; + for(const auto& i : x) + result.insert(classify(i)); + for(const auto& c : result) + std::cout << c << ", "; + std::cout << std::endl; +} + +template +void print_statistics(const Buffer& x) +{ + std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", "; + std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", "; + double num_elements = x.size(); + auto mean = + std::accumulate(x.begin(), x.end(), double{0.0}, std::plus{}) / num_elements; + auto stddev = std::sqrt( + std::accumulate(x.begin(), + x.end(), + double{0.0}, + [&](double r, double v) { return r + std::pow((v 
- mean), 2.0); }) / + num_elements); + std::cout << "Mean: " << mean << ", "; + std::cout << "StdDev: " << stddev << "\n"; +} + +template +void print_preview(const Buffer& x) +{ + if(x.size() <= 10) + { + std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; }); + } + else + { + std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; }); + std::cout << "..., "; + std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; }); + } + std::cout << std::endl; +} + +template +struct check_all +{ + rtc::buffer data{}; + bool operator()(const rtc::buffer& x) + { + if(data.empty()) + { + data = x; + return true; + } + return allclose(data, x); + } +}; + +template +auto report(const Solution& solution, bool pass) +{ + return test::make_predicate(solution.ToTemplateString(), [=] { return pass; }); +} diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index 17b659993..bd7ef463f 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -10,6 +10,7 @@ #include #include #include +#include using half = _Float16; // using half = __fp16; @@ -159,7 +160,10 @@ TEST_CASE(test_problem_kernel) auto b = to_gpu(generate_buffer(1024 * 1024, 1)); auto c = to_gpu(generate_buffer(1024 * 1024, 2)); - for(auto solution : prob.GetSolutions("gfx90a")) + std::string epilogue = ""; + std::string prologue = ""; + + for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue)) { auto src = ck::host::InterpolateString(gemm_compile_check, {{"include", prob.GetIncludeHeader()}, @@ -178,6 +182,7 @@ TEST_CASE(test_problem_kernel) auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) * ck::host::integer_divide_ceil(prob.N, n_per_block); k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data()); + CHECK(report(solution, check(rtc::from_gpu(c)))); } } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp new file mode 100644 index 000000000..3c477692e --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include +#include +#include +#include "common.hpp" +#include + +// Need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( 
+struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {2, 2}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {1, 1}; + ck::Array input_right_pads = {1, 1}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {2, 2}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {1, 1}; + std::vector input_right_pads_ = {1, 1}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size 
calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp new file mode 100644 index 000000000..ec9bd2b78 --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include +#include +#include +#include + +// need this for validation +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array 
in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {1, 1}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {0, 0}; + ck::Array input_right_pads = {0, 0}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {1, 1}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {0, 0}; + std::vector input_right_pads_ = {0, 0}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { 
test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp new file mode 100644 index 000000000..9850184c5 --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "common.hpp" +#include +#include +#include +#include + +// need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {2, 2}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {0, 0}; + ck::Array input_right_pads = {0, 0}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, 
out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {2, 2}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {0, 0}; + std::vector input_right_pads_ = {0, 0}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp new file mode 100644 index 000000000..907f744db --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "common.hpp" +#include +#include +#include +#include + +// need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, 
const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {1, 1}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {1, 1}; + ck::Array input_right_pads = {1, 1}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {1, 1}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {1, 1}; + std::vector input_right_pads_ = {1, 1}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + 
ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 7ea55b932..d84ebf4de 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -56,6 +56,8 @@ void write_string(const std::string& filename, const std::string_view& buffer) } std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device-only"; } +// TODO: undo after extracting the codeobj +// std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; } kernel compile_kernel(const std::vector& srcs, compile_options options) { @@ -89,6 +91,12 @@ kernel compile_kernel(const std::vector& srcs, compile_options options auto obj = read_buffer(out_path.string()); + std::ofstream ofh("obj.o", std::ios::binary); + for(auto i : obj) + ofh << i; + ofh.close(); + // int s = std::system(("/usr/bin/cp " + out_path.string() + " codeobj.bin").c_str()); + // assert(s == 0); return kernel{obj.data(), options.kernel_name}; } diff --git a/codegen/test/rtc/src/hip.cpp b/codegen/test/rtc/src/hip.cpp index 10e38c9ad..747f83e3b 100644 --- a/codegen/test/rtc/src/hip.cpp +++ b/codegen/test/rtc/src/hip.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace rtc { @@ -49,7 +50,10 @@ std::size_t get_available_gpu_memory() size_t total; auto status = hipMemGetInfo(&free, &total); if(status != hipSuccess) - throw std::runtime_error("Failed getting available memory: " + hip_error(status)); + { + std::cerr << "Failed getting available memory: " + hip_error(status) << std::endl; + return (8ull * 1024ull * 1024ull * 1024ull); + } return free; } diff --git a/include/ck/tensor_operation/gpu/device/helper.hpp b/include/ck/tensor_operation/gpu/device/helper.hpp new file mode 100644 index 000000000..c52566509 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/helper.hpp @@ -0,0 +1,359 @@ +#pragma 
once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include +#include + +// functions to return the corresponding structs based on generated template parameters + +using layouts = std::variant; +// return the layout type: currently this is the only type supported in MIOpen +auto layout_type(std::string type) +{ + if(type == "ck::tensor_layout::convolution::NHWGK") + { + return ck::tensor_layout::convolution::NHWGK{}; + } + throw std::runtime_error("Incorrect layout"); +} +// return the right gemm spec based on the generated template parameters +ck::tensor_operation::device::GemmSpecialization gemm_type(std::string type) +{ + if(type == "ck::tensor_operation::device::GemmSpecialization::Default") + { + return ck::tensor_operation::device::GemmSpecialization::Default; + } + if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding") + { + return ck::tensor_operation::device::GemmSpecialization::MNKPadding; + } + throw std::runtime_error("Incorrect gemm spec: " + type); +} + +// return the type of convolution +ck::tensor_operation::device::ConvolutionForwardSpecialization conv_type(std::string type) +{ + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + } + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + } + if(type == + "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + } + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + } + throw std::runtime_error("Incorrect conv spec: " + type); +} + +// Function to call on MatrixPadder via a wrapper struct +// NOTE: CK only uses MNKPadding for forward convolution +template +auto pad(ck::index_t mpb, + ck::index_t npb, + ck::index_t kpb, + ck::tensor_operation::device::GemmSpecialization gemm, + CDesc_MRaw_NRaw conv) +{ + if(gemm == ck::tensor_operation::device::GemmSpecialization::MNKPadding) + { + ck::tensor_operation::device::MatrixPadder< + ck::tensor_operation::device::GemmSpecialization::MNKPadding, + ck::index_t, + ck::index_t, + ck::index_t> + a; + a.MPerTile_ = mpb; + a.NPerTile_ = npb; + a.KPerTile_ = kpb; + auto tmp = grid_desc(a, conv); + return tmp; + } + throw std::runtime_error("Incorrect template parameters, check gemm spec"); +} + +// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num +// dims +// FIXME: add a way to properly pass in the layout +auto transform_conv(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + 
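+        // transform_func (defined on the TransformConv wrapper in transform_conv_fwd_to_gemm.hpp,
+        // later in this patch) dispatches on the spatial dimension count to the matching
+        // MakeCDescriptor_M_N overload; the resulting (M, N) output descriptor is what pad()
+        // wraps with the MatrixPadder further down in this header.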
return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + throw std::runtime_error("Incorrect conv spec"); +} + +auto transform_conv_3d(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + throw std::runtime_error("Incorrect conv spec"); +} + +auto transform_conv_1d(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + 
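+    // The remaining branches below differ only in the ConvolutionForwardSpecialization value
+    // baked into TransformConvFwdToGemm; the dispatch pattern is otherwise identical.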
if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + throw std::runtime_error("Incorrect dims or conv spec"); +} + +template +auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder) +{ + if(m_per_block == 32 && n_per_block == 64) + { + auto b2e = ck::BlockToCTileMap_M00_N0_M01Adapt<32, 64, CGridDesc_M_N>(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 32 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<32, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 32) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 32, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 64) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 64, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 32) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 32, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 64) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 64, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 256) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 256, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 256 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + throw std::runtime_error("Incorrect template parameters"); +} + +// wrapper functions by dims to get grid size - uses above 3 functions +// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions +auto get_launch_params_1d(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto 
m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} + +auto get_launch_params(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} + +auto get_launch_params_3d(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 000000000..7ef4e7f18 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,781 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). 
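+ * \note A sketch of the evenly strided case (assumed behaviour of
+ *       ComputePtrOffsetOfStridedBatch): GetAPtrOffset(g) == g * BatchStrideA_ and
+ *       GetEPtrOffset(g) == g * BatchStrideE_, so work-group g only shifts its base pointers
+ *       by the per-group stride before running the usual GridwiseGemm tile.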
+ * + */ +template +__device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + if constexpr(isMultiA || isMultiB) + { + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; + + const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + + const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + + static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } + else + { + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + + GridwiseGemm::template Run( + p_as_grid + a_batch_offset, + p_bs_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = 
e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ + + device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + AsPointer, // tuples if multi AB, pointers if no + BsPointer, + DsPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfBatch, + HasMainKBlockLoop, + isMultiA, + isMultiB>(p_as_grid, + p_bs_grid, + p_ds_grid, + *p_e_grid, + a_element_op, + b_element_op, + cde_element_op, + batch_count, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map, + compute_ptr_offset_of_batch); +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +// +// @brief Device Convolution operation. 
+// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + LoopScheduler LoopSched = make_default_loop_scheduler()> +struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. 
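+    // For example, a D tensor that is just a per-K bias can pass the E lengths here together
+    // with its own strides; the bias overload of MakeCDescriptor_M_N (in the conv-to-gemm
+    // transformer, later in this patch) then uses a zero stride along the merged N*Ho*Wo
+    // dimension, so the single value per K is broadcast logically.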
+ __host__ __device__ static auto MakeDsGridDescriptor_M_N( + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + ds_g_n_k_wos_strides[i]); + }, + Number{}); + } + + // desc for problem definition + using AGridDesc_M_K = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using BGridDesc_N_K = remove_cvref_t({}, {}))>; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t({}, {}))>; + + // If we are using multiAB and one of the template datatype parameters is not a tuple, convert + // it to it + using GemmADataType = std::conditional_t, ADataType>; + using GemmBDataType = std::conditional_t, BDataType>; + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + // Use appropriate gridwise gemm + using GridwiseGemm = + std::conditional_t, + GridwiseGemmMultipleD_xdl_cshuffle>; + + // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. + using APointers = + std::conditional_t&, const void*>; + using BPointers = + std::conditional_t&, const void*>; + // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not + // in initializer list what is required for single const pointer). 
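+    // In other words, AGridPointer/BGridPointer are always tuple types (even when only a single
+    // pointer is passed in), so the Argument constructor can assign their elements one by one
+    // instead of relying on an initializer list.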
+ using AGridPointer = remove_cvref_t< + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + using BGridPointer = remove_cvref_t< + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument + { + __device__ __host__ Argument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + if constexpr(isMultiA || isMultiB) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + + // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data + // type is not tuple) + using DataType = remove_cvref_t>; + // It is possible that one of the 
AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiA) + { + // p_as is tuple + p_as_grid_(i) = static_cast(p_as[i.value]); + } + else + { + // if MultiB and not MultiA then p_as is single pointer + p_as_grid_(i) = static_cast(p_as); + } + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiB) + { + // p_bs is tuple + p_bs_grid_(i) = static_cast(p_bs[i.value]); + } + else + { + // if MultiA and not MultiB then p_bs is single pointer + p_bs_grid_(i) = static_cast(p_bs); + } + }); + } + else + { + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + + // p_as and p_bs are pointers + p_as_grid_(I0) = static_cast(p_as); + p_bs_grid_(I0) = static_cast(p_bs); + } + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); + }); + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate desc for Ds/E + if constexpr(isMultiA || isMultiB) + { + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); + + if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + else + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + } + + // private: + // pointers (tuple if multi AB, pointer if no) + AGridPointer p_as_grid_; + BGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + 
ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch + compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + ck::Array a_g_n_c_wis_lengths_; + ck::Array a_g_n_c_wis_strides_; + ck::Array b_g_k_c_xs_lengths_; + ck::Array b_g_k_c_xs_strides_; + ck::Array, NumDTensor> ds_g_n_k_wos_lengths_; + ck::Array, NumDTensor> ds_g_n_k_wos_strides_; + ck::Array e_g_n_k_wos_lengths_; + ck::Array e_g_n_k_wos_strides_; + ck::Array conv_filter_strides_; + ck::Array conv_filter_dilations_; + ck::Array input_left_pads_; + ck::Array input_right_pads_; + }; + + static __device__ __host__ auto MakeArgument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index c66d2fc51..029415314 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -180,6 +180,19 @@ struct MatrixPadder : public GemmPadder +auto grid_desc(MatrixPadder matrix_padder, + CDesc_MRaw_NRaw conv_desc) +{ + auto res = matrix_padder.PadCDescriptor_M_N(conv_desc); + return res; +} // M/N/KPerTileType could be index_t or Number<> template +__host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T init) +{ + for(ForwardIterator x = first; x != first + count; x++) + { + init *= *x; + } + return init; +} + template struct TransformConvFwdToGemm { @@ -607,6 +618,559 @@ struct TransformConvFwdToGemm return out_gemmm_gemmn_desc; } + + // Overloaded functions for hipRTC purposes + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ static auto + MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */, + const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& 
input_left_pads, + const ck::Array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Wi = a_g_n_c_wis_lengths[3]; + + const index_t Wo = c_g_n_k_wos_lengths[3]; + + const index_t ConvStrideW = conv_filter_strides[0]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t WiStride = a_g_n_c_wis_strides[3]; + const auto CStride = I1; + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t X = b_g_k_c_xs_lengths[3]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t WiStride = a_g_n_c_wis_strides[3]; + const auto CStride = I1; + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ static auto + MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + 
const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */, + const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Hi = a_g_n_c_wis_lengths[3]; + const index_t Wi = a_g_n_c_wis_lengths[4]; + + const index_t Ho = c_g_n_k_wos_lengths[3]; + const index_t Wo = c_g_n_k_wos_lengths[4]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t HiStride = a_g_n_c_wis_strides[3]; + const index_t WiStride = a_g_n_c_wis_strides[4]; + const auto CStride = I1; + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Y = b_g_k_c_xs_lengths[3]; + const index_t X = b_g_k_c_xs_lengths[4]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t HiStride = a_g_n_c_wis_strides[3]; + const index_t WiStride = a_g_n_c_wis_strides[4]; + const auto CStride = I1; + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmk_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + static auto + MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */, + const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads) + { + const index_t N = a_g_n_c_wis_lengths[1]; + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Di = a_g_n_c_wis_lengths[3]; + const index_t Hi = a_g_n_c_wis_lengths[4]; + const index_t Wi = a_g_n_c_wis_lengths[5]; + + const index_t Do = c_g_n_k_wos_lengths[3]; + const index_t Ho = c_g_n_k_wos_lengths[4]; + const index_t Wo = c_g_n_k_wos_lengths[5]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NDoHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + // This is different + const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto in_gemmm_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride)); + + return in_gemmm_gemmk_desc; + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t DiStride = a_g_n_c_wis_strides[3]; + const index_t HiStride = a_g_n_c_wis_strides[4]; + const index_t WiStride = a_g_n_c_wis_strides[5]; + const auto CStride = I1; + + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto 
in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + else + { + const index_t Z = b_g_k_c_xs_lengths[3]; + const index_t Y = b_g_k_c_xs_lengths[4]; + const index_t X = b_g_k_c_xs_lengths[5]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + // This is different + const index_t NStride = a_g_n_c_wis_strides[1]; + const index_t DiStride = a_g_n_c_wis_strides[3]; + const index_t HiStride = a_g_n_c_wis_strides[4]; + const index_t WiStride = a_g_n_c_wis_strides[5]; + const auto CStride = I1; + + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_gemmm_gemmk_desc; + } + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ static auto + MakeBDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */) + { + const index_t K = b_g_k_c_xs_lengths[1]; + const index_t C = b_g_k_c_xs_lengths[2]; + + const index_t YX = + mult_accumulate_n(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1); + + const auto wei_gemmn_gemmk_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); + + return wei_gemmn_gemmk_desc; + } + + template < + typename BLayout, + typename std::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ static auto + 
MakeBDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides) + { + const index_t K = b_g_k_c_xs_lengths[1]; + const index_t C = b_g_k_c_xs_lengths[2]; + + const index_t YX = + mult_accumulate_n(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1); + + const index_t KStride = b_g_k_c_xs_strides[1]; + const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( + make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride)); + + const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor( + wei_k_yx_c_desc, + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return wei_gemmn_gemmk_desc; + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ static auto + MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */) + { + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + + const index_t NHoWo = + N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); + + const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K)); + + return out_gemmm_gemmn_desc; + } + + template < + typename CLayout, + typename std::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ static auto + MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& c_g_n_k_wos_strides) + { + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + + const auto KStride = I1; + const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2]; + + const index_t NHoWo = + N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); + + const auto out_gemmm_gemmn_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride)); + + return out_gemmm_gemmn_desc; + } + + // for output bias + template , + bool>::type = false> + __host__ __device__ static auto + MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& c_g_n_k_wos_strides) + { + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + const index_t KStride = c_g_n_k_wos_strides[2]; + + const index_t NHoWo = + N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); + + const auto out_gemmm_gemmn_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride)); + + return out_gemmm_gemmn_desc; + } +}; + +// wrapper class to call member functions on TransformConvToGemm struct at runtime +// TODO: figure out aq way to properly pass in layout as an argument +struct TransformConv +{ + TransformConv() {} + + template + auto + transform_func(ck::Array out_lengths, + ck::Array out_strides, + TransformConvFwdToGemm conv_fwd_to_gemm) + { + if(NDimSpatial == 2) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(out_lengths, + out_strides); + } + else if(NDimSpatial == 3) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(out_lengths, + out_strides); + } + else if(NDimSpatial == 1) + { + return conv_fwd_to_gemm.template MakeCDescriptor_M_N( + out_lengths, out_strides); + } + } }; } // namespace tensor_operation diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index f63ce5e5a..5366c56a9 
100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -36,6 +36,8 @@ struct Array return *this; } + __host__ __device__ constexpr const TData* begin() const { return &mData[0]; } + __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; } }; // empty Array -- GitLab From 0cb2e06ddcdf6b1414cbead4262b76b5d4391e93 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 26 Jun 2024 17:41:15 +0800 Subject: [PATCH 66/96] [CK_TILE] fmha forward split-kv + combine kernels (#1338) * FA fwd dropout * FA bwd * epilogue reuse * CMakeLists update * [CK_TILE] support alibi (#1269) * add alibi support * fix code * update code based on comment * Support more hdim * fix fp8 bias * support seqlen_k=0 case * remove unused printf * fix format --------- Co-authored-by: rocking * now fwd/bwd can build * bwd alibi * add bwd validation stream_config * update generated filenames * update bwd kernel launch * CK_TILE_HOST_DEVICE in philox * Transpose -> transpose * format * format * format * Generate the instance for FA required * format * fix error in WarpGemm * Add num_splits option and dummy split-kv api method * Generate fmha_fwd_splitkv() * Add SplitKV kernel codegen logics * Add SplitKV combine kernel codegen logics * Fix mismatched return type * Clean-up code * Replace sentinel value before storing * Fix wrong layout of LSE/LSEacc/Oacc * Format codes * Fix o_acc memory error * Fix wrong kBlockSize used in policy * Reduce # of combine kernels * Fix split-kv combine kernel name * Fix wrong LDS indexing logics * Fix wrong loop counter step logic * Undo vector size changes * Remove no-longer used field * Remove in-consistent comment * Remove debug statements in example * Remove more debug statements * Add constness to local variables * Clearn up generate.py * Fix unstable clang-format comment * Remove unused include directive * Use shorter template parameter name * Enable non-split-kv blobs * Update license date * Print num_splits conditionally * Undo disabling data types * Remove unnessary tile size for fp8 * Fix wrong pipeline args for fp8 * Fix example output format * Remove more debug code in combine pipeline * Add stride kernel arguments for LSE/O acc workspace * Re-order split-kv pipeline call operator arguments * Pass LSE/O strides in kernel argument * Re-order pipeline call operator arguments * Use tensor_descriptor to locate LSEacc elements * Support providing invalid element for tensor view * Set invalid element value for LSEacc tensor view * Remove hand-written store_tile() code * Remove necessary value-overwrite logic * Add transposed lds descriptor * Support load_tile() for tile_window_with_static_lengths<> * Undo removing necessary value-overwrite logic * Use read descriptor to locate lds elements * Simplify pipeline source code * Add constraint to kMaxSplits * Default use kMaxSplits=64 in generate.py * Revert "Add constraint to kMaxSplits" This reverts commit 0a2132d758042e6fb0292f4e354909b8a4d1c118. * Revert "Default use kMaxSplits=64 in generate.py" This reverts commit c7d9c80b77320aec6559222bed7d47adcaefe4e3. 
* Decide alignment by the padding parameter * Remove no-longer used utility functions * Remove not-working code * Add comment & remove no-longer used code * Fix computation errors * Add heuristic to override num_splits option * Add constraint to kMaxSplits * Fix compilation error * Clean up pipeline code * Wrap pointer access as lambda function * Rename confusing methods * Use kLogMasSplits as template parameter * Finish splitkv combine kernel codegen * Update kMaxSplits limit * Use smaller kM0 for splitkv combine kernel * Ignore droupout flag in splitkv pipeline * Unify flag usage * Add back flag kStoreLSE * Merge lambda calls in pipeline * Fix compilation errors * Avoid all empty splits * Always check for empty loop in splitkv pipelines * Re-order parameters * Remove redundant p_drop option check * Add traits/problem for fwd splitkv kernel * Conditionally enable uneven split boundary checks * Add comment for the splitkv traits field * Change even split criteria * Re-order statements * Refine occupancy value for hdim=128&256 * Refine occupancy value for hdim=32&64 * Remove redundant kernel argument * Separate fmha bwd codegen logics * Separate fmha fwd codegen logics * Remove redundant direction parameter in fwd&bwd codegen logics * Support generate multiple APIs for an example * Let 'api' an alias of 'direction' option * Remove choices for the 'direction' option * Use dictionary to config all the functions * Move fmha splitkv codegen logics to other file * Add fwd_splitkv api for tile_example_fmha_fwd --------- Co-authored-by: danyao12 Co-authored-by: carlushuang Co-authored-by: rocking Co-authored-by: Jing Zhang --- example/ck_tile/01_fmha/CMakeLists.txt | 8 +- example/ck_tile/01_fmha/codegen/__init__.py | 0 .../ck_tile/01_fmha/codegen/cmake_config.py | 5 + .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 92 ++ .../ck_tile/01_fmha/codegen/ops/__init__.py | 0 .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 611 +++++++++ .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 498 +++++++ .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 671 +++++++++ example/ck_tile/01_fmha/fmha_fwd.cpp | 170 ++- example/ck_tile/01_fmha/fmha_fwd.hpp | 215 +++ example/ck_tile/01_fmha/generate.py | 1217 +---------------- include/ck_tile/ops/fmha.hpp | 10 + .../ck_tile/ops/fmha/block/block_masking.hpp | 17 + .../fmha_fwd_splitkv_combine_kernel.hpp | 455 ++++++ ...a_fwd_splitkv_combine_tile_partitioner.hpp | 49 + .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 913 +++++++++++++ .../fmha_fwd_splitkv_tile_partitioner.hpp | 53 + ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 314 +++++ ...plitkv_combine_pipeline_default_policy.hpp | 175 +++ ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 666 +++++++++ ...ha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp | 770 +++++++++++ ...pipeline_qr_ks_vs_async_default_policy.hpp | 19 + ...litkv_pipeline_qr_ks_vs_default_policy.hpp | 19 + .../pipeline/block_fmha_pipeline_problem.hpp | 65 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 44 + 25 files changed, 5858 insertions(+), 1198 deletions(-) create mode 100644 example/ck_tile/01_fmha/codegen/__init__.py create mode 100644 example/ck_tile/01_fmha/codegen/cmake_config.py create mode 100644 example/ck_tile/01_fmha/codegen/cpp_symbol_map.py create mode 100644 example/ck_tile/01_fmha/codegen/ops/__init__.py create mode 100644 example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py create mode 100644 example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py create mode 100644 example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py create mode 100644 
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index e324f85ed..e30e9e793 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,12 +1,12 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt ) execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt + --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt ) # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory @@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) add_custom_command( OUTPUT ${FMHA_BWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") diff --git a/example/ck_tile/01_fmha/codegen/__init__.py b/example/ck_tile/01_fmha/codegen/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/example/ck_tile/01_fmha/codegen/cmake_config.py b/example/ck_tile/01_fmha/codegen/cmake_config.py new file mode 100644 index 000000000..03ebfd670 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/cmake_config.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +GEN_DIR = "" # in Cmake, have to generate files in same folder \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py new file mode 100644 index 000000000..d3d215f7f --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+# generate kernel instances to speed up compilation + +DTYPE_MAP = { + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "fp8" : "ck_tile::fp8_t" +} + +MASK_IMPL = { + "generic" : "ck_tile::GenericAttentionMask", + "simplified" : "ck_tile::SimplifiedGenericAttentionMask" +} + +_MASK_SIMPLIFIED_MAP = { + "s_no" : "ck_tile::SimplifiedGenericAttentionMask", + "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", +} + +_MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +def get_mask_map(mask : str): + if mask == "generic": + return _MASK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_MAP + else: + assert False + return None + +_MASK_CHECK_MAP = { + "no" : "t.mask_type == mask_enum::no_mask", + "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", + "generic" : "t.mask_type == mask_enum::window_generic", +} + +_MASK_SIMPLIFIED_CHECK_MAP = { + "s_no" : "t.mask_type == mask_enum::no_mask", + "s_mask" : "t.mask_type != mask_enum::no_mask", +} + +def get_mask_check_map(mask : str): + if mask == "generic": + return _MASK_CHECK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_CHECK_MAP + else: + assert False + return None + +BIAS_MAP = { + "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" +} + +# TODO: this is ugly +BIAS_CHECK_MAP = { + "no" : "bias_enum::no_bias", + "bias" : "bias_enum::elementwise_bias", + "alibi" : "bias_enum::alibi" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +LAYOUT_MAP = { + "row" : "true", + "col" : "false" +} + +PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", +} + +PIPELINE_ENUM_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", +} + +BOOL_MAP = { + "t" : "true", + "f" : "false" +} \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/__init__.py b/example/ck_tile/01_fmha/codegen/ops/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py new file mode 100644 index 000000000..0160915a5 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -0,0 +1,611 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +BWD_DQDKDV_PIPELINE_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR", +} + +BWD_DQDKDV_PIPELINE_ENUM_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR", +} + +FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n +// auto generated by generate.py +#include "fmha_bwd.hpp" +""" + +FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; +using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; +using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; +using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + {F_dbias}, + false, + {F_dropout}, + false, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_bwd_trait_{F_idx}>; + +using fmha_bwd_pipeline_{F_idx} = {F_pipeline}< + fmha_bwd_pipeline_problem_{F_idx}>; + +using fmha_bwd_dk_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, + false, false>>; + +using fmha_bwd_dv_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, + false, false>>; + +using fmha_bwd_dq_dk_dv_kernel_{F_idx} = + ck_tile::FmhaBwdDQDKDVKernel, + fmha_bwd_pipeline_{F_idx}, + fmha_bwd_dk_epilogue_{F_idx}, + fmha_bwd_dv_epilogue_{F_idx}>; + +using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> 
+std::string fmha_bwd_dq_dk_dv_get_name_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" +FMHA_BWD_API=""" +#include + +template +float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} + ); +}} + +float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; + r = fmha_bwd_(s, a); + return r; + }} +""" + +@dataclass +class FmhaBwdDQDKDVApiTrait: + pipeline : str + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along k seqlen + bhdq : int # q head_dim + bhdv : int # v head_dim + mask : str + bias : str + dbias : str + dropout : str + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + def scheck(self, spad1 : str) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad == 't' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} != 0' + elif self.spad == 'f' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize + else: # self.skpad == 'f' and skpad1 == 'f' + return f'a.seqlen_q % 256 == 0' # BlockSize + + @property + def skcheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.skpad == 't': + return f'a.seqlen_k % {self.bn0} != 0' + else: + return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' + else : return f'a.hdim_q % {self.bhdq} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' + else : return f'a.hdim_v % {self.bhdv} == 0' + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = dict() + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: + # TODO: do we need to check duplication? 
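+        # pool layout built below (inferred from this method): {dtype: {hdim: [FmhaBwdDQDKDVApiTrait, ...]}};
+        # duplicate traits are not filtered out here, per the TODO above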
+ if trait.dtype not in self.dq_dk_dv_pool.keys(): + self.dq_dk_dv_pool[trait.dtype] = dict() + if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): + self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() + + self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): + traits=self.dq_dk_dv_pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + for spad1 in ["t", "f"]: + if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): + continue + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], + F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) + + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + +# GEMM0: Q@K=S^T +# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) +# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) +# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) +# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) +# Is it necessary to distinguish between K0~K4? 
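The comment block above enumerates the five GEMMs of the FMHA backward pass. A rough NumPy sketch of how those products chain together (illustrative only; it is not part of this patch, assumes a single head with 2-D tensors, and omits the softmax scaling, masking, bias, and dropout handled by the real pipelines):

import numpy as np

def ref_fmha_bwd(q, k, v, do):
    # forward recompute: GEMM0 is S = Q @ K^T, followed by a row-wise softmax
    s = q @ k.T
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    o = p @ v
    d = (do * o).sum(axis=-1, keepdims=True)  # row sums of dO*O, what the dot_do_o kernel precomputes
    dv = p.T @ do                             # GEMM1: dV = P^T @ dO
    dp = do @ v.T                             # GEMM2: dP = dO @ V^T
    ds = p * (dp - d)                         # softmax backward
    dk = ds.T @ q                             # GEMM3: dK = dS^T @ Q
    dq = ds @ k                               # GEMM4: dQ = dS @ K
    return dq, dk, dv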
+@dataclass +class FmhaBwdDQDKDVTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along gemm0 unroll(F_bhdq) + F_bk1 : int # tile size along gemm1 unroll(F_bm0) + F_bk2 : int # tile size along gemm2 unroll(F_bhdv) + F_bk3 : int # tile size along gemm3 unroll(F_bm0) + F_bk4 : int # tile size along gemm4 unroll(F_bn0) + F_bhdq : int # q head_dim + F_bhdv : int # v head_dim + F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 + F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 + F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2 + F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 + F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3 + F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3 + F_rm2 : int # number of warps along k seqlen (block warps) in gemm4 + F_rn2 : int # number of warps along q seqlen (block warps) in gemm4 + F_rk2 : int # number of warps along gemm-k (not used) in gemm4 + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ + f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" + +@dataclass +class FmhaBwdDQDKDVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaBwdDQDKDVTileSize + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # + F_dbias : str # + F_dropout : str # + F_mask : str # value from MASK_MAP + F_mode : str # value from MODE_MAP + F_pipeline : str + mask_impl : str + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bk1 = self.F_tile.F_bk1, + F_bk2 = self.F_tile.F_bk2, + F_bk3 = self.F_tile.F_bk3, + F_bk4 = self.F_tile.F_bk4, + F_bhdq = self.F_tile.F_bhdq, + F_bhdv = self.F_tile.F_bhdv, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_rm2 = self.F_tile.F_rm2, + F_rn2 = self.F_tile.F_rn2, + F_rk2 = self.F_tile.F_rk2, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_spad = BOOL_MAP[self.F_spad], + F_skpad = BOOL_MAP[self.F_skpad], + F_dpad = BOOL_MAP[self.F_dpad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_bias = BIAS_MAP[self.F_bias], + F_dbias = BOOL_MAP[self.F_dbias], + F_dropout = BOOL_MAP[self.F_dropout], + F_occupancy = self.F_tile.F_occupancy, + F_mask = get_mask_map(self.mask_impl)[self.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], + F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + 
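+            # remaining 'dv' pad flag is appended below; any collected flags become a 'p'-prefixed suffix of the kernel name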
if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_dbias == 't' : n += '_dbias' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_dropout == 't' : n += '_dropout' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaBwdDQDKDVApiTrait: + return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bhdq=self.F_tile.F_bhdq, + bhdv=self.F_tile.F_bhdv, + mask=self.F_mask, + bias=self.F_bias, + dbias=self.F_dbias, + dropout=self.F_dropout, + spad=self.F_spad, + skpad=self.F_skpad, + dpad=self.F_dpad, + dvpad=self.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size & pipeline. +def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1), + "qs_ks_vr_dos"], + '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "qs_ks_vr_dos"], + '128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "ks_vr"] + } + else: + return None + +def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: + # TODO: we don't support tuning yet, so pick up one value for pad + # support this in future + gen = list() + api_pool = FmhaBwdApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): + tile = d[hdim_str][0] + ppl = d[hdim_str][1] + hdim = int(hdim_str) + if (mode == "group") and (spad == "f" or skpad == "f"): + continue + if ((bias == "no" or bias == "alibi") and dbias == "t"): + continue + k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, + F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, + F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, + F_pipeline=ppl, mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + if not cond: + continue + api_pool.register_dq_dk_dv_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, + {F_dvpad}, + {F_occupancy}>; + +using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 256, + {F_hdim}, + {F_mode}, + 
fmha_bwd_dot_do_o_trait_{F_idx}>; + +using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO< + fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; + +using fmha_bwd_dot_do_o_kernel_{F_idx} = + ck_tile::FmhaBwdOGradDotOKernel, + fmha_bwd_dot_do_o_{F_idx}>; + +using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; + +#include + +template<> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +std::string fmha_bwd_dot_do_o_get_name_() +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +@dataclass +class FmhaBwdOGradDotOKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_spad : str # true/false + F_dvpad : str # + F_mode : str # value from MODE_MAP + F_occupancy : int + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_spad = BOOL_MAP[self.F_spad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_mode = MODE_MAP[self.F_mode], + F_occupancy = self.F_occupancy) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" + if pn != '' : n += f'_{pn}' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + +def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + gen = list() + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): + hdim = int(hdim_str) + if (mode == "group" and spad == "f"): + continue + k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, + F_spad=spad, F_dvpad=dvpad, F_mode=mode, + F_occupancy=get_occupancy(dtype, hdim)) + gen.append(k) + + return gen + +def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: 
Path) -> None: + (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + write_single_bwd_dot_do_o_kernel(kernel, output_dir) + api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) + write_bwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py new file mode 100644 index 000000000..1486671f6 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -0,0 +1,498 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +TILE_PARTITIONER_MAP = { + "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", + "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n +// auto generated by generate.py +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdKernel<{F_tile_partitioner}, + fmha_pipeline_{F_idx}, + fmha_epilogue_{F_idx}>; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_fwd_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with 
fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0blen : int + vlayout : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {self.bk0blen} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) + else : return f'a.hdim_v % {self.bk0blen} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' + if self.F_squant == 't' : n += '_squant' + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm : int # number of warps along q seqlen (block warps) + F_rn : int # number of warps along k seqlen(not used) + F_rk : int # number of warps along 
gemm-k(not used) + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ + f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + def get_tp(self) -> str: + if self.F_mode == 'group': + return 'hbs' + else: + return 'shb' + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 
1, 1, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + if hdim == 256: + # if True: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + if receipt == 1: + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse/dropout kernels + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + 
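+                # receipt 2 (as the conditions below suggest) narrows codegen to fp16/bf16, row-major V, no-bias or alibi, non-quantized kernels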
cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py new file mode 100644 index 000000000..419fbaaea --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -0,0 +1,671 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple, Union + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + +from codegen.ops.fmha_fwd import ( + FmhaFwdTileSize, + FmhaFwdApiTrait, + FMHA_FWD_KERNEL_HEADER, + FMHA_FWD_API_PER_DTYPE, + FMHA_FWD_API_PER_HDIM_CASE, +) + + +FMHA_FWD_SPLITKV_PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", +} + +FMHA_FWD_SPLITKV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_mask_{F_idx} = {F_mask}; + +namespace {{ +template +struct kernel_runner {{ +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_warps = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; +using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape = ck_tile::TileFmhaShape; + +using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + kHasUnevenSplits, + {F_occupancy}>; + +using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::OaccDataType, + fmha_shape, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait>; + +using fmha_pipeline = {F_pipeline}< + fmha_pipeline_problem>; + +using 
fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel = + ck_tile::FmhaFwdSplitKVKernel, + fmha_pipeline, + fmha_epilogue>; + +static void run(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} +}}; +}} + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + if constexpr({F_mode} == false) {{ // batch mode + if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ + kernel_runner::run(s, a); + }} else {{ + kernel_runner::run(s, a); + }} + }} else {{ + kernel_runner::run(s, a); + }} +}} + +template<> +std::string fmha_fwd_splitkv_get_name_() +{{ + using k_ = kernel_runner::fmha_kernel; /// FIXME: choose real kernel type + return k_::GetName(); +}} +""" + +FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +namespace {{ +template +struct kernel_runner {{ +using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, + {F_dvpad}, + {F_lse}, + {F_squant}, + kLogMaxSplits, + {F_occupancy}>; + +using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {F_hdim}, + {F_bm0}, + {F_bn1}, + {F_mode}, + fmha_trait>; + +using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< + fmha_pipeline_problem>; + +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel = + ck_tile::FmhaFwdSplitKVCombineKernel, + fmha_pipeline, + fmha_epilogue>; + +static void run(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} +}}; +}} + +using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1}, + {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; + +#include + +template<> +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + if (a.num_splits <= 16) {{ + kernel_runner<4>::run(s, a); + }} else if (a.num_splits <= 32) {{ + kernel_runner<5>::run(s, a); + }} else if (a.num_splits <= 64) {{ + kernel_runner<6>::run(s, a); + }} else if (a.num_splits <= 128) {{ + kernel_runner<7>::run(s, a); + }} +}} + +template<> +std::string fmha_fwd_splitkv_combine_get_name_() +{{ + using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type + return k_::GetName(); +}} +""" + +FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp" +FMHA_FWD_SPLITKV_API=""" +#include + +template +float fmha_fwd_splitkv_(const 
ck_tile::stream_config& s, fmha_fwd_args a) +{{ + if(s.log_level_ > 0) + std::cout + << ", " << fmha_fwd_splitkv_get_name_() + << ", " << fmha_fwd_splitkv_combine_get_name_() + << std::flush; + + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} + ); +}} + +float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} +""" + +@dataclass +class FmhaFwdSplitKVPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' + if self.F_squant == 't' : n += '_squant' + return n + +@dataclass +class FmhaFwdSplitKVCombinePipeline: + tag : str + + F_spad : str # true/false + F_dvpad : str # + F_lse : str # + F_squant : str # + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}' + if pn != '' : n += f'_{pn}' + if self.F_lse == 't' : n += '_lse' + if self.F_squant == 't' : n += '_squant' + return n + +class FmhaFwdSplitKVApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? 
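+ # (illustrative sketch, not exercised by the current generator) one way to resolve the
+ # TODO above would be to skip traits whose generated name is already registered, e.g.:
+ #   existing = self.pool.get(trait.dtype, {}).get(trait.hdim, [])
+ #   if any(t.name == trait.name for t in existing):
+ #       return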
+ if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdSplitKVCombineTileSize: + F_bm0 : int # tile size along q seqlen + F_bn1 : int # tile size along v head_dim + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdSplitKVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdSplitKVPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_SPLITKV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = 
PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +@dataclass +class FmhaFwdSplitKVCombineKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdSplitKVCombineTileSize + F_pipeline : FmhaFwdSplitKVCombinePipeline + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn1 = self.F_tile.F_bn1, + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_mode = MODE_MAP[self.F_mode]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 
64, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) + } + else: + return None + +def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), + '64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), + '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), + '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), + '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), + '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), + } + else: + return None + +def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: + Pipeline = FmhaFwdSplitKVPipeline + Kernel = FmhaFwdSplitKVKernel + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + # splitkv kernel donot support dropout + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]): + if hdim == 256: + # if True: + pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + else: + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + if receipt == 1: + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse/dropout kernels + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdSplitKVApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in 
get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = Kernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]: + Pipeline = FmhaFwdSplitKVCombinePipeline + Kernel = FmhaFwdSplitKVCombineKernel + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]): + pipelines.append(Pipeline('unused', spad, dvpad, lse, squant)) + elif dtype in ['fp8', 'bf8']: + # no need lse kernels + pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant)) + else: + assert False + return pipelines + + gen = list() + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = Kernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + gen.append(k) + + return gen + +def write_single_kernel(kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None: + file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME + file_path.write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + api_pool, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_fwd_splitkv_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + kernels = 
get_fwd_splitkv_combine_blobs(kernel_filter, receipt) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 5f887f065..28f790573 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -114,6 +114,9 @@ auto create_args(int argc, char* argv[]) .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("num_splits", + "1", + "# of splits for key/value. 0 to determine actual number by heuristic") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -155,6 +158,106 @@ auto get_elimit(std::string init_method) } } +int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) +{ + // If we have enough to almost fill the SMs, then just use 1 split + if(batch_nhead_mblocks >= 0.8f * num_SMs) + { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. 
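+ // Concrete numbers for the example above (illustrative): ceildiv(64, 11) == 6 != ceildiv(64, 10) == 7,
+ // so 11 splits is eligible; ceildiv(64, 12) == 6 == ceildiv(64, 11), so 12 splits is skipped.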
+ auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || + ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { + efficiency.push_back(0.f); + } + else + { + float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if(eff > max_efficiency) + { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { + continue; + } + if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) + { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +int override_num_splits_if_necessary( + int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) +{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return num_splits; + } + + hipDeviceProp_t props{}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return num_splits; + } + + // tile size should match the generate.py + const int kM0 = 64; + const int kN1 = hdim_v; + + const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); + const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); + + if(num_splits < 1 && p_drop == 0.0f) + { + return num_splits_heuristic( + batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + } + + return num_splits; +} + +float fmha_fwd_dispatch(fmha_fwd_traits traits, + fmha_fwd_args args, + const ck_tile::stream_config& config) +{ + if(1 < args.num_splits) + { + return fmha_fwd_splitkv(traits, args, config); + } + else + { + return fmha_fwd(traits, args, config); + } +} + template bool run(const ck_tile::ArgParser& arg_parser) { @@ -260,6 +363,8 @@ bool run(const ck_tile::ArgParser& arg_parser) seed.reset(); } + int num_splits = arg_parser.get_int("num_splits"); + int stream_warmup = arg_parser.get_int("warmup"); int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); @@ -320,6 +425,18 @@ bool run(const ck_tile::ArgParser& arg_parser) } } + // legalize num_splits according to other options + if(num_splits < 1) + { + num_splits = override_num_splits_if_necessary( + batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); + } + if(128 < num_splits) + { + std::cerr << "num_splits greater than 128 is not supported" << std::endl; + return false; + } + auto get_lengths = [&](bool permute, ck_tile::index_t b /*batch*/, ck_tile::index_t h /*nhead*/, @@ -361,7 +478,15 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{batch, nhead}) : std::array{1, 1}); - // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] + ck_tile::HostTensor lse_acc_host( + 1 < num_splits ? std::array{num_splits, batch, nhead, max_seqlen_q} + : std::array{1, 1, 1, 1}); + ck_tile::HostTensor o_acc_host( + 1 < num_splits + ? std::array{num_splits, batch, nhead, max_seqlen_q, hdim_v} + : std::array{1, 1, 1, 1, 1}); + + // self define lse data layout as [batch, nhead, max_seqlen_q] ck_tile::HostTensor lse_host( lse ? 
std::array{batch, nhead, max_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); @@ -443,6 +568,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); @@ -479,7 +606,12 @@ bool run(const ck_tile::ArgParser& arg_parser) : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant - << ", mask:" << mask << ", v:" << vlayout << std::flush; + << ", mask:" << mask << ", v:" << vlayout; + if(1 < num_splits) + { + std::cout << ", num_splits:" << num_splits; + } + std::cout << std::flush; auto fmha_traits = fmha_fwd_traits{hdim_q, hdim_v, @@ -523,6 +655,7 @@ bool run(const ck_tile::ArgParser& arg_parser) }(); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_o_acc = hdim_v; const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); @@ -537,6 +670,8 @@ bool run(const ck_tile::ArgParser& arg_parser) (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = max_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q; + const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); @@ -545,7 +680,12 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + // setup split_stride_* arguments (only used in split-kv kernel) + const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q); + const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); return fmha_fwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), @@ -553,6 +693,8 @@ bool run(const ck_tile::ArgParser& arg_parser) bias.type == bias_enum::alibi ? 
alibi_slope_buf.GetDeviceBuffer() : bias_buf.GetDeviceBuffer(), randval_buf.GetDeviceBuffer(), + lse_acc_buf.GetDeviceBuffer(), + o_acc_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), @@ -566,6 +708,7 @@ bool run(const ck_tile::ArgParser& arg_parser) hdim_v, nhead, nhead_k, + num_splits, scale_s, scale_p, scale_o, @@ -575,6 +718,7 @@ bool run(const ck_tile::ArgParser& arg_parser) bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias, stride_randval, + stride_o_acc, stride_o, nhead_stride_q, nhead_stride_k, @@ -582,6 +726,8 @@ bool run(const ck_tile::ArgParser& arg_parser) nhead_stride_bias, nhead_stride_randval, nhead_stride_lse, + nhead_stride_lse_acc, + nhead_stride_o_acc, nhead_stride_o, batch_stride_q, batch_stride_k, @@ -589,7 +735,11 @@ bool run(const ck_tile::ArgParser& arg_parser) batch_stride_bias, batch_stride_randval, batch_stride_lse, + batch_stride_lse_acc, + batch_stride_o_acc, batch_stride_o, + split_stride_lse_acc, + split_stride_o_acc, mask.left, mask.right, static_cast(mask.type), @@ -598,7 +748,7 @@ bool run(const ck_tile::ArgParser& arg_parser) {drop_seed, drop_offset}}; }(); - float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); + float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config); if(ave_time < 0) { @@ -849,14 +999,14 @@ bool run(const ck_tile::ArgParser& arg_parser) lse_host_result.ForEach( [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); - bool lse_pass = ck_tile::check_err(lse_host_result, - lse_host_ref, - "LSE Error: Incorrect results!", - rtol, - atol, - /* allow_infinity_ref = */ true); + cur_pass = ck_tile::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); - pass &= lse_pass; + pass &= cur_pass; if(!cur_pass) { std::cerr << "LSE mismatch found at batch: " << wb << std::endl diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 3594f61db..ee932ce5d 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -93,6 +93,8 @@ struct fmha_fwd_args const void* v_ptr; const void* bias_ptr; // bias or alibi_slope pointer void* rand_val_ptr; + void* lse_acc_ptr; + void* o_acc_ptr; void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -106,6 +108,7 @@ struct fmha_fwd_args ck_tile::index_t hdim_v; ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; + ck_tile::index_t num_splits; float scale_s; float scale_p; float scale_o; @@ -114,6 +117,7 @@ struct fmha_fwd_args ck_tile::index_t stride_v; ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_randval; + ck_tile::index_t stride_o_acc; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; @@ -121,6 +125,8 @@ struct fmha_fwd_args ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -128,7 +134,11 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o; + ck_tile::index_t split_stride_lse_acc; + 
ck_tile::index_t split_stride_o_acc; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; @@ -234,6 +244,176 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) return ck_tile::make_tuple(kargs, grids); } +template +auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.max_seqlen_q, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.max_seqlen_q, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = + Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel argumentszs + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.max_seqlen_q, + args.seqstart_q_ptr, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.batch_stride_lse, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.max_seqlen_q, + args.seqlen_q, + args.hdim_v, + 
args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.batch_stride_lse, + args.batch_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + }(); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + // this is used to pattern-match internl kernel implementation, not to instantiate kernel template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); +template +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); + +template +std::string fmha_fwd_splitkv_get_name_(); + +template +struct fmha_fwd_splitkv_combine_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); + +template +std::string fmha_fwd_splitkv_combine_get_name_(); + // This is the public API, will be generated by script struct fmha_fwd_traits { @@ -298,3 +512,4 @@ struct fmha_fwd_traits // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); +float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index e0b4b6559..27347b447 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -3,1214 +3,62 @@ # generate kernel instances to speed up compilation import argparse -import itertools +from enum import IntEnum from pathlib import Path -from typing import List, Optional, Tuple -from dataclasses import dataclass -import copy -import fnmatch +from typing import List, Optional -DTYPE_MAP = { - "fp16": "ck_tile::fp16_t", - "bf16": "ck_tile::bf16_t", - "fp8" : "ck_tile::fp8_t" -} - -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} - -MASK_IMPL = { - "generic" : "ck_tile::GenericAttentionMask", - "simplified" : "ck_tile::SimplifiedGenericAttentionMask" -} - -MASK_SIMPLIFIED_MAP = { - "s_no" : "ck_tile::SimplifiedGenericAttentionMask", - "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", -} - -MASK_MAP = { - "no" : "FmhaMasks::NoMask", - "causal" : "FmhaMasks::CausalMask", - "generic" : "FmhaMasks::GenericMask" -} - -BIAS_MAP = { - "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", - "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", - "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" -} - -# TODO: this is ugly -BIAS_CHECK_MAP = { - "no" : "bias_enum::no_bias", - "bias" : "bias_enum::elementwise_bias", - "alibi" : "bias_enum::alibi" -} - -MODE_MAP = { - "batch" : "false", - "group" : "true" -} - -LAYOUT_MAP = { - "row" : "true", - "col" : "false" -} - -PIPELINE_MAP = { - "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", - "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", -} - -PIPELINE_ENUM_MAP = { - "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async" : 
"ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", -} - -BOOL_MAP = { - "t" : "true", - "f" : "false" -} - -TILE_PARTITIONER_MAP = { - "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", - "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", -} - -GEN_DIR = "" # in Cmake, have to generate files in same folder - -FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n -// auto generated by generate.py -#include "fmha_fwd.hpp" -""" - -FMHA_FWD_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; -using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; -using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; - -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; - -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_bias}, - false, - {F_lse}, - {F_dropout}, - {F_squant}, - {F_occupancy}>; -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, - {F_mode}, - fmha_mask_{F_idx}, - fmha_trait_{F_idx}>; - -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; - -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; - -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel<{F_tile_partitioner}, - fmha_pipeline_{F_idx}, - fmha_epilogue_{F_idx}>; - -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - -#include - -template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) -{{ - using k_ = fmha_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} -""" - -FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" -FMHA_FWD_API=""" -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ - float r = -1; -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ -{F_inner_dispatch} - }} -""" -MASK_CHECK_MAP = { - "no" : "t.mask_type == mask_enum::no_mask", - "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", - "generic" : "t.mask_type == mask_enum::window_generic", -} - -MASK_SIMPLIFIED_CHECK_MAP = { - "s_no" : 
"t.mask_type == mask_enum::no_mask", - "s_mask" : "t.mask_type != mask_enum::no_mask", -} - -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - return fmha_fwd_(s, a); - }} -""" - -def get_mask_map(mask : str): - if mask == "generic": - return MASK_MAP - elif mask == "simplified": - return MASK_SIMPLIFIED_MAP - else: - assert False - return None - -def get_mask_check_map(mask : str): - if mask == "generic": - return MASK_CHECK_MAP - elif mask == "simplified": - return MASK_SIMPLIFIED_CHECK_MAP - else: - assert False - return None - -@dataclass -class FmhaFwdApiTrait: - pipeline_tag : str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0blen : int - vlayout : str - mask : str - bias : str # - lse : str # - dropout : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - - @property - def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' - - @property - def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False - - @property - def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' - else: assert False - - @property - def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr']: - if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) - else : return f'a.hdim_q % {self.bk0blen} == 0' - else: assert False - - @property - def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr']: - if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {self.bk0blen} == 0' - else: assert False - -@dataclass -class FmhaFwdPipeline: - tag : str - - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_bias : str # true/false - F_lse : str # - F_dropout : str # - F_squant : str # - F_mask : str # value from MASK_MAP - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - if self.F_bias != 'no' : n += f'_{self.F_bias}' - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_lse == 't' : n += '_lse' - if self.F_dropout == 't' : n += '_dropout' - if self.F_squant == 't' : n += '_squant' - return n - -class FmhaFwdApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait : FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() - - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, - F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) - -@dataclass -class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # 
tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm : int # number of warps along q seqlen (block warps) - F_rn : int # number of warps along k seqlen(not used) - F_rk : int # number of warps along gemm-k(not used) - F_wm : int # warp size along m (warp size) - F_wn : int # warp size along n - F_wk : int # warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - @property - def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ - f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - -@dataclass -class FmhaFwdKernel: - direction : str - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline - mask_impl : str - - def get_tp(self) -> str: - if self.F_mode == 'group': - return 'hbs' - else: - return 'shb' - - @property - def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0blen = self.F_tile.F_bk0blen, - F_rm = self.F_tile.F_rm, - F_rn = self.F_tile.F_rn, - F_rk = self.F_tile.F_rk, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0blen=self.F_tile.F_bk0blen, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - 
dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) - -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: - if direction == 'fwd': - if dtype == 'fp16' or dtype == 'bf16': - return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': - return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) - } - else: - return None - else: - return None - -def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' - pipelines = [] - if dtype in ['fp16', 'bf16']: - for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - if hdim == 256: - # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - if receipt == 1: - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: - # no need lse/dropout kernels - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) - else: - assert False - return pipelines +from codegen.cmake_config import * +from codegen.ops import ( + fmha_fwd, + fmha_fwd_splitkv, + fmha_bwd +) - gen = list() - api_pool = FmhaFwdApiPool(mask_impl) - for direction, dtype in itertools.product(["fwd"], DTYPE_MAP.keys()): - d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) - if d == None: - continue - #for hdim_str, mode, 
mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): - if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - k = FmhaFwdKernel(direction=direction, - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != None: - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - if not cond: - continue - api_pool.register_traits(k.api_trait()) - gen.append(k) +class HandlerId(IntEnum): + LIST_BLOBS = 0 + WRITE_BLOBS = 1 - return (api_pool, gen) - -BWD_DQDKDV_PIPELINE_MAP = { - "ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", - "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS", - "ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR", -} - -BWD_DQDKDV_PIPELINE_ENUM_MAP = { - "ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR", - "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS", - "ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR", +handlers = { + 'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs), + 'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs), + 'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs), } -FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n -// auto generated by generate.py -#include "fmha_bwd.hpp" -""" - -FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; -using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; -using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; -using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; -using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; - -// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape -// G0&G2 -> GSdP -// G1&G3 -> GdKV -// G4 -> GdQ -using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; - -using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_bias}, - {F_dbias}, - false, - {F_dropout}, - false, - {F_occupancy}>; -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< - typename FmhaBwdTypeConfig::QDataType, - typename FmhaBwdTypeConfig::KDataType, - typename FmhaBwdTypeConfig::VDataType, - typename FmhaBwdTypeConfig::GemmDataType, - typename FmhaBwdTypeConfig::LSEDataType, - typename FmhaBwdTypeConfig::AccDataType, - typename FmhaBwdTypeConfig::DDataType, - typename FmhaBwdTypeConfig::BiasDataType, - typename FmhaBwdTypeConfig::RandValOutputDataType, - typename FmhaBwdTypeConfig::ODataType, - typename FmhaBwdTypeConfig::OGradDataType, - typename FmhaBwdTypeConfig::QGradDataType, - typename FmhaBwdTypeConfig::KGradDataType, - typename FmhaBwdTypeConfig::VGradDataType, - typename FmhaBwdTypeConfig::BiasGradDataType, - fmha_bwd_shape_{F_idx}, - {F_mode}, - fmha_mask_{F_idx}, - fmha_bwd_trait_{F_idx}>; - -using fmha_bwd_pipeline_{F_idx} = {F_pipeline}< - fmha_bwd_pipeline_problem_{F_idx}>; - -using fmha_bwd_dk_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::AccDataType, - typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, - false, false>>; - -using fmha_bwd_dv_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::AccDataType, - typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, - false, false>>; - -using fmha_bwd_dq_dk_dv_kernel_{F_idx} = - ck_tile::FmhaBwdDQDKDVKernel, - fmha_bwd_pipeline_{F_idx}, - fmha_bwd_dk_epilogue_{F_idx}, - fmha_bwd_dv_epilogue_{F_idx}>; - -using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - -#include - -template<> -float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} - -template<> -void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); -}} - -template<> 
-std::string fmha_bwd_dq_dk_dv_get_name_() -{{ - using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; - return k_::GetName(); -}} -""" - -FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" -FMHA_BWD_API=""" -#include - -template -float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << std::flush; - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} - ); -}} - -float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ - float r = -1; -{F_dispatch} - return r; -}} -""" - -FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ -{F_inner_dispatch} - }} -""" - -FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; - r = fmha_bwd_(s, a); - return r; - }} -""" - -@dataclass -class FmhaBwdDQDKDVApiTrait: - pipeline : str - # sync with fmha_bwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along k seqlen - bhdq : int # q head_dim - bhdv : int # v head_dim - mask : str - bias : str - dbias : str - dropout : str - spad : str - skpad : str - dpad : str - dvpad : str - - @property - def name(self) -> str: - return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' - - def scheck(self, spad1 : str) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.spad == 't' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} != 0' - elif self.spad == 'f' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize - else: # self.skpad == 'f' and skpad1 == 'f' - return f'a.seqlen_q % 256 == 0' # BlockSize - - @property - def skcheck(self) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.skpad == 't': - return f'a.seqlen_k % {self.bn0} != 0' - else: - return f'a.seqlen_k % {self.bn0} == 0' - - @property - def dcheck(self) -> str: - if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' - else : return f'a.hdim_q % {self.bhdq} == 0' - - @property - def dvcheck(self) -> str: - if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' - else : return f'a.hdim_v % {self.bhdv} == 0' - -class FmhaBwdApiPool: - def __init__(self, mask_impl): - self.dq_dk_dv_pool = dict() - self.mask_impl = mask_impl - - def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: - # TODO: do we need to check duplication? 
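The scheck(spad1) helper in FmhaBwdDQDKDVApiTrait above builds the runtime seqlen_q predicate that pairs one padding variant of the dq_dk_dv kernel with one padding variant of the dot_do_o kernel. A minimal Python sketch of the three batch-mode predicates it can emit, assuming a bm0 tile of 128 (the 256 literal is the fixed BlockSize of the dot_do_o kernel, as the comments note):

    # illustrative sketch only; mirrors FmhaBwdDQDKDVApiTrait.scheck() for batch mode
    def scheck(spad: str, spad1: str, bm0: int = 128, block_size: int = 256) -> str:
        if spad == 't' and spad1 == 't':
            return f'a.seqlen_q % {bm0} != 0'
        elif spad == 'f' and spad1 == 't':
            return f'a.seqlen_q % {bm0} == 0 and a.seqlen_q % {block_size} != 0'
        else:  # spad == 'f' and spad1 == 'f'
            return f'a.seqlen_q % {block_size} == 0'

    print(scheck('t', 't'))  # a.seqlen_q % 128 != 0
    print(scheck('f', 't'))  # a.seqlen_q % 128 == 0 and a.seqlen_q % 256 != 0
    print(scheck('f', 'f'))  # a.seqlen_q % 256 == 0

Group mode always returns 'true' because padding is forced on there, so every seqlen_q is accepted at runtime.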
- if trait.dtype not in self.dq_dk_dv_pool.keys(): - self.dq_dk_dv_pool[trait.dtype] = dict() - if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): - self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() - - self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): - traits=self.dq_dk_dv_pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - for spad1 in ["t", "f"]: - if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): - continue - inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], - F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], - F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) - - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - - return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) - -# GEMM0: Q@K=S^T -# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) -# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) -# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) -# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) -# Is it necessary to distinguish between K0~K4? 
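The comment block above pins down the roles of the five backward GEMMs. For reference, here is a plain NumPy shape model of the attention backward pass those GEMMs implement (hypothetical sizes; the kernel's actual operand layouts and transposes differ, this only shows which dimension each free N_i must match):

    import numpy as np

    M, N, Dqk, Dv = 128, 256, 64, 64              # seqlen_q, seqlen_k, hdim_qk, hdim_v
    Q, K, V = np.random.rand(M, Dqk), np.random.rand(N, Dqk), np.random.rand(N, Dv)
    dO = np.random.rand(M, Dv)

    S = Q @ K.T                                   # GEMM0: scores, (M, N)
    P = np.exp(S - S.max(-1, keepdims=True)); P /= P.sum(-1, keepdims=True)
    dV = P.T @ dO                                 # GEMM1: (N, Dv)  -> N1 must equal hdim_v
    dP = dO @ V.T                                 # GEMM2: (M, N)
    dS = P * (dP - (dP * P).sum(-1, keepdims=True))
    dK = dS.T @ Q                                 # GEMM3: (N, Dqk) -> N3 must equal hdim_qk
    dQ = dS @ K                                   # GEMM4: (M, Dqk) -> N4 must equal hdim_qk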
-@dataclass -class FmhaBwdDQDKDVTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along gemm0 unroll(F_bhdq) - F_bk1 : int # tile size along gemm1 unroll(F_bm0) - F_bk2 : int # tile size along gemm2 unroll(F_bhdv) - F_bk3 : int # tile size along gemm3 unroll(F_bm0) - F_bk4 : int # tile size along gemm4 unroll(F_bn0) - F_bhdq : int # q head_dim - F_bhdv : int # v head_dim - F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 - F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 - F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2 - F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 - F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3 - F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3 - F_rm2 : int # number of warps along k seqlen (block warps) in gemm4 - F_rn2 : int # number of warps along q seqlen (block warps) in gemm4 - F_rk2 : int # number of warps along gemm-k (not used) in gemm4 - F_wm : int # warp size along m (warp size) - F_wn : int # warp size along n - F_wk : int # warp size along k - F_occupancy : int # occupancy - @property - def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ - f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" - -@dataclass -class FmhaBwdDQDKDVKernel: - direction : str - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_tile : FmhaBwdDQDKDVTileSize - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_bias : str # - F_dbias : str # - F_dropout : str # - F_mask : str # value from MASK_MAP - F_mode : str # value from MODE_MAP - F_pipeline : str - mask_impl : str - - @property - def template(self) -> str: - return FMHA_BWD_KERNEL_HEADER + \ - FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bk1 = self.F_tile.F_bk1, - F_bk2 = self.F_tile.F_bk2, - F_bk3 = self.F_tile.F_bk3, - F_bk4 = self.F_tile.F_bk4, - F_bhdq = self.F_tile.F_bhdq, - F_bhdv = self.F_tile.F_bhdv, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_rm2 = self.F_tile.F_rm2, - F_rn2 = self.F_tile.F_rn2, - F_rk2 = self.F_tile.F_rk2, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, - F_spad = BOOL_MAP[self.F_spad], - F_skpad = BOOL_MAP[self.F_skpad], - F_dpad = BOOL_MAP[self.F_dpad], - F_dvpad = BOOL_MAP[self.F_dvpad], - F_bias = BIAS_MAP[self.F_bias], - F_dbias = BOOL_MAP[self.F_dbias], - F_dropout = BOOL_MAP[self.F_dropout], - F_occupancy = self.F_tile.F_occupancy, - F_mask = get_mask_map(self.mask_impl)[self.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], - F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad 
== 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name - if pn != '' : n += f'_{pn}' - if self.F_bias != 'no' : n += f'_{self.F_bias}' - if self.F_dbias == 't' : n += '_dbias' - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_dropout == 't' : n += '_dropout' - return n - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaBwdDQDKDVApiTrait: - return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bhdq=self.F_tile.F_bhdq, - bhdv=self.F_tile.F_bhdv, - mask=self.F_mask, - bias=self.F_bias, - dbias=self.F_dbias, - dropout=self.F_dropout, - spad=self.F_spad, - skpad=self.F_skpad, - dpad=self.F_dpad, - dvpad=self.F_dvpad) - -# TODO: design a more practical way to do it -# this is current supported tile size & pipeline. -def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: - if direction == 'bwd': - if dtype == 'fp16' or dtype == 'bf16': - return { - '32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1), - "qs_ks_vr_dos"], - '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), - "qs_ks_vr_dos"], - '128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), - "ks_vr"] - } - else: - return None - else: - return None - -def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: - # TODO: we don't support tuning yet, so pick up one value for pad - # support this in future - gen = list() - api_pool = FmhaBwdApiPool(mask_impl) - - for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) - if d == None: - continue - for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): - tile = d[hdim_str][0] - ppl = d[hdim_str][1] - hdim = int(hdim_str) - if (mode == "group") and (spad == "f" or skpad == "f"): - continue - if ((bias == "no" or bias == "alibi") and dbias == "t"): - continue - k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, - F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, - F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, - F_pipeline=ppl, mask_impl=mask_impl) - if kernel_filter != None: - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - if not cond: - continue - api_pool.register_dq_dk_dv_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - -FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, - {F_dvpad}, - {F_occupancy}>; - -using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< - typename 
FmhaBwdTypeConfig::ODataType, - typename FmhaBwdTypeConfig::OGradDataType, - typename FmhaBwdTypeConfig::DDataType, - /* BlockSize = */ 256, - {F_hdim}, - {F_mode}, - fmha_bwd_dot_do_o_trait_{F_idx}>; - -using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO< - fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; - -using fmha_bwd_dot_do_o_kernel_{F_idx} = - ck_tile::FmhaBwdOGradDotOKernel, - fmha_bwd_dot_do_o_{F_idx}>; - -using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; - -#include - -template<> -float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); -}} - -template<> -void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) -{{ - using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; - auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); -}} - -template<> -std::string fmha_bwd_dot_do_o_get_name_() -{{ - using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; - return k_::GetName(); -}} -""" - -@dataclass -class FmhaBwdOGradDotOKernel: - direction : str - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_spad : str # true/false - F_dvpad : str # - F_mode : str # value from MODE_MAP - F_occupancy : int - - @property - def template(self) -> str: - return FMHA_BWD_KERNEL_HEADER + \ - FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_spad = BOOL_MAP[self.F_spad], - F_dvpad = BOOL_MAP[self.F_dvpad], - F_mode = MODE_MAP[self.F_mode], - F_occupancy = self.F_occupancy) - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" - if pn != '' : n += f'_{pn}' - return n - - @property - def filename(self) -> str: - return self.name + ".cpp" - -def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 - - gen = list() - - for direction, dtype in itertools.product(["bwd"], DTYPE_MAP.keys()): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) - if d == None: - continue - for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): - hdim = int(hdim_str) - if (mode == "group" and spad == "f"): - continue - k = FmhaBwdOGradDotOKernel(direction=direction+"_dot_do_o", F_idx=0, F_hdim=hdim, F_dtype=dtype, - F_spad=spad, F_dvpad=dvpad, F_mode=mode, - F_occupancy=get_occupancy(dtype, hdim)) - gen.append(k) - - return gen - -def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - (autogen_dir / 
kernel.filename).write_text(kernel.template) - -def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) - -def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) - -def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: output_dir = Path(output_dir) / GEN_DIR output_dir.mkdir(parents=True, exist_ok=True) - if direction == 'fwd': - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - write_single_fwd_kernel(kernel, output_dir) - write_fwd_api(api_pool, output_dir) - else: - kernels = get_bwd_dot_do_o_blobs() - for kernel in kernels: - write_single_bwd_dot_do_o_kernel(kernel, output_dir) - api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) - write_bwd_api(api_pool, output_dir) + + for api in api_list: + handler = handlers[api][HandlerId.WRITE_BLOBS] + handler(output_dir, kernel_filter, receipt, mask_impl) # list all the files that will be generated -def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) - with file_path.open('a') as f: - if direction == 'fwd': - _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") - else: - kernels = get_bwd_dot_do_o_blobs() - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") + + for api in api_list: + handler = handlers[api][HandlerId.LIST_BLOBS] + handler(file_path, kernel_filter, receipt, mask_impl) if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate", - description="gen api for CK fmha kernel", + description="gen API for CK fmha kernel", ) parser.add_argument( "-d", - "--direction", + "--direction", # we keep 'direction' option for backward compatibility + "-a", + "--api", default='fwd', - choices=['fwd', 'bwd'], required=False, - help="choose the direction of kernels(default: fwd)" + help="supply API(s) to generate (default: fwd). separated by comma." 
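With the refactor above, generate.py no longer branches on a single fwd/bwd direction: the comma-separated -d/--direction (alias -a/--api) value is split into api_list, and each entry is dispatched through the handlers table introduced earlier in this patch. A self-contained sketch of that flow, using stub functions in place of the real codegen.ops modules:

    from enum import IntEnum

    class HandlerId(IntEnum):
        LIST_BLOBS = 0
        WRITE_BLOBS = 1

    def make_stub(api_name):
        # stand-ins for fmha_fwd / fmha_fwd_splitkv / fmha_bwd list_blobs + write_blobs
        def list_blobs(file_path, kernel_filter, receipt, mask_impl):
            print(f"{api_name}: would append generated file names to {file_path}")
        def write_blobs(output_dir, kernel_filter, receipt, mask_impl):
            print(f"{api_name}: would write kernels into {output_dir} (filter={kernel_filter})")
        return (list_blobs, write_blobs)

    handlers = {api: make_stub(api) for api in ('fwd', 'fwd_splitkv', 'bwd')}

    # equivalent of: python generate.py -a fwd,fwd_splitkv -o build
    api_list = 'fwd,fwd_splitkv'.split(',')
    for api in api_list:
        handlers[api][HandlerId.WRITE_BLOBS]('build/', None, 0, 'simplified')

Keeping the old --direction spelling as an alias means existing build scripts that pass -d fwd or -d bwd keep working unchanged.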
) parser.add_argument( "-o", @@ -1251,7 +99,8 @@ if __name__ == "__main__": ) args = parser.parse_args() + api_list = args.direction.split(',') if args.list_blobs is not None: - list_blobs(args.list_blobs, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) + list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, args.direction, args.filter, int(args.receipt), mask_impl=args.mask) + write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) \ No newline at end of file diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 568486830..057d2b11f 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -10,6 +10,10 @@ #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" @@ -22,6 +26,12 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index ce8493663..c022edf72 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -299,6 +299,23 @@ struct SimplifiedGenericAttentionMask } } + template + CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, + number height, + number width, + index_t num_splits, + index_t i_split) const + { + auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); + + const index_t x_per_split = ck_tile::max(1, x_total / num_splits); + const index_t split_start = x_per_split * i_split; + const index_t split_end = (i_split == num_splits - 1 ? 
x_total : split_start + x_per_split); + + return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), + ck_tile::min(origin_end, split_end)); + } + // to get the loop length along Y axis, return index:[start, end), end-start=length // use this if need loop over Y axis tile by tile (like q-seqlen loopover) // TODO: y_end still could be negative, so end-start could be negative(need check) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp new file mode 100644 index 000000000..6f4313d5b --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -0,0 +1,455 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +template +struct FmhaFwdSplitKVCombineKernel +{ + using TilePartitioner = remove_cvref_t; + using FmhaPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using LSEDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + __host__ static std::string GetName() + { + // sync with generate.py + // clang-format off + + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(FmhaPipeline::kM0) + "x" + + _TS_(FmhaPipeline::kN1) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + + _SS_(FmhaPipeline::name) + + (pn.empty() ? "" : "_" + pn) + + (kStoreLSE ? "_lse" : "" ) + + (kDoFp8StaticQuant ? "_squant" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct EmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. 
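The GetTileRangeAlongX(num_splits, i_split) overload added to SimplifiedGenericAttentionMask above slices the per-row K range into num_splits contiguous chunks, lets the last split absorb the remainder, and clips each chunk against the mask-dependent [origin_start, origin_end) range. A small Python model of that arithmetic (values made up for illustration):

    def split_tile_range(origin_start, origin_end, x_total, num_splits, i_split):
        # mirrors the split computation in GetTileRangeAlongX(..., num_splits, i_split)
        x_per_split = max(1, x_total // num_splits)
        split_start = x_per_split * i_split
        split_end = x_total if i_split == num_splits - 1 else split_start + x_per_split
        return max(origin_start, split_start), min(origin_end, split_end)

    # 3 splits over 300 K columns, with the mask limiting this row block to [0, 260)
    for i_split in range(3):
        print(i_split, split_tile_range(0, 260, 300, num_splits=3, i_split=i_split))
    # 0 (0, 100)   1 (100, 200)   2 (200, 260)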
+ struct CommonKargs + { + const void* lse_acc_ptr; + const void* o_acc_ptr; + void* o_ptr; + + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + + ck_tile::index_t seqlen_q; + ck_tile::index_t hdim_v; + ck_tile::index_t num_splits; + + ck_tile::index_t row_stride_o_acc; + ck_tile::index_t row_stride_o; + + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + ck_tile::index_t nhead_stride_o; + + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + }; + + struct CommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; + }; + + struct Fp8StaticQuantKargs + { + float scale_o; + }; + + struct BatchModeKargs + : CommonKargs, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_o; + }; + + struct GroupModeKargs + : CommonKargs, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + }; + + using Kargs = std::conditional_t; + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* lse_acc_ptr, + const void* o_acc_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits, + float scale_o, + ck_tile::index_t row_stride_o_acc, + ck_tile::index_t row_stride_o, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc) + { + Kargs kargs{{lse_acc_ptr, + o_acc_ptr, + o_ptr, + batch, + max_seqlen_q, + seqlen_q, + hdim_v, + num_splits, + row_stride_o_acc, + row_stride_o, + nhead_stride_lse_acc, + nhead_stride_o_acc, + nhead_stride_o, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + batch_stride_o}; + + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_o = scale_o; + } + + return kargs; + } + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* lse_acc_ptr, + const void* o_acc_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + const void* seqstart_q_ptr, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits, + float scale_o, + ck_tile::index_t row_stride_o_acc, + ck_tile::index_t row_stride_o, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc) + { + Kargs kargs{{lse_acc_ptr, + o_acc_ptr, + o_ptr, + batch, + max_seqlen_q, + -1, // seqlen will be updated by another pointer + hdim_v, + num_splits, + row_stride_o_acc, + row_stride_o, + nhead_stride_lse_acc, + nhead_stride_o_acc, + nhead_stride_o, + batch_stride_lse_acc, 
+ batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + reinterpret_cast(seqstart_q_ptr)}; + + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_o = scale_o; + } + + return kargs; + } + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); + } + + __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = + TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + const long_index_t batch_offset_lse_acc = + static_cast(i_batch) * kargs.batch_stride_lse_acc; + const long_index_t batch_offset_o_acc = + static_cast(i_batch) * kargs.batch_stride_o_acc; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + + batch_offset_o = query_start * kargs.row_stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + } + else + { + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + // for simplicity, batch stride we just modify the pointer + const LSEDataType* lse_acc_ptr = + reinterpret_cast(kargs.lse_acc_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lse_acc + batch_offset_lse_acc; + const OaccDataType* o_acc_ptr = + reinterpret_cast(kargs.o_acc_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // LSEacc/Oacc DRAM and DRAM windows + const auto lse_acc_dram = [&]() { + const auto lse_acc_dram_naive = make_naive_tensor_view( + lse_acc_ptr, + make_tuple(kargs.num_splits, kargs.seqlen_q), + make_tuple(kargs.split_stride_lse_acc, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + lse_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_acc_dram = [&]() { + const auto o_acc_dram_naive = make_naive_tensor_view( + o_acc_ptr, + make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v), + make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1), + number{}, + number<1>{}); + + auto o_acc_dram_view = pad_tensor_view( + o_acc_dram_naive, + make_tuple(number<1>{}, 
number{}, number{}), + sequence{}); + + const index_t padded_max_seqlen_q = + o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; + const index_t padded_hdim_v = + o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; + + return transform_tensor_view( + o_acc_dram_view, + make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)), + make_pass_through_transform(padded_hdim_v)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + }(); + + auto lse_acc_dram_window = make_tile_window( + lse_acc_dram, + [&]() { + return make_tuple(number{}, number{}); + }(), + {0, i_m0}); + + auto o_acc_dram_window = make_tile_window( + o_acc_dram, + [&]() { + return make_tuple(number{}, number{}); + }(), + {i_m0, i_n1}); + + // LSE DRAM window + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + lse_acc_dram_window, + o_acc_dram_window, + lse_dram_window, + identity{}, // lse_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + kargs.num_splits, + kargs.max_seqlen_q, + smem_ptr); + } + else + { + return FmhaPipeline{}(lse_acc_dram_window, + o_acc_dram_window, + lse_dram_window, + kargs.num_splits, + kargs.max_seqlen_q, + smem_ptr); + } + }(); + + // O DRAM and DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.row_stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp new file mode 100644 index 000000000..9f04843a3 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaFwdSplitKVCombineTilePartitioner +{ + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN1 = kN1_; + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * + ck_tile::integer_divide_ceil(hdim_v_, kN1), + nhead_, + batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp new file mode 100644 index 000000000..45ed185ad --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -0,0 +1,913 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct FmhaFwdSplitKVKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + using OaccDataType = remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + 
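FmhaFwdSplitKVCombineTilePartitioner above launches one workgroup per (M0 tile, N1 tile, head, batch): GridSize packs the M and N tile counts into blockIdx.x, and operator() recovers the pair with a divide/modulo by the number of N1 tiles. A Python sketch of that mapping (tile sizes chosen arbitrarily for illustration):

    from math import ceil

    def grid_size(batch, nhead, seqlen_q, hdim_v, kM0, kN1):
        return (ceil(seqlen_q / kM0) * ceil(hdim_v / kN1), nhead, batch)

    def decompose(i_block, hdim_v, kN1):
        num_tile_n1 = ceil(hdim_v / kN1)
        return divmod(i_block, num_tile_n1)      # (i_tile_m, i_tile_n)

    gx, gy, gz = grid_size(batch=2, nhead=8, seqlen_q=192, hdim_v=128, kM0=64, kN1=64)
    print(gx, gy, gz)                            # 6 8 2
    print([decompose(b, hdim_v=128, kN1=64) for b in range(gx)])
    # [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]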
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + __host__ static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using gbr = typename bfs::Gemm0BlockWarps; + using gwt = typename bfs::Gemm0WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct EmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct CommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* lse_acc_ptr; + void* o_acc_ptr; + + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. 
This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + ck_tile::index_t num_splits; + + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o_acc; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + }; + + struct CommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct BatchModeBiasKargs : CommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct AlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct MaskKargs + { + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct Fp8StaticQuantKargs + { + float scale_p; + }; + + struct CommonDropoutKargs + { + void init_dropout(const float p_drop, + const std::tuple& drop_seed_offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + drop_seed = std::get<0>(drop_seed_offset); + drop_offset = std::get<1>(drop_seed_offset); + } + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + uint64_t drop_seed = 1; + uint64_t drop_offset = 0; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + struct BatchModeDropoutKargs : CommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct BatchModeKargs + : CommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + }; + + struct GroupModeKargs + : CommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_acc_ptr, + void* o_acc_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + ck_tile::index_t num_splits, + float scale_s, + float scale_p, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o_acc, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t 
nhead_stride_randval, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_acc_ptr, + o_acc_ptr, + batch, + max_seqlen_q, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + num_splits, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o_acc, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_lse_acc, + nhead_stride_o_acc, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + batch_stride_q, + batch_stride_k, + batch_stride_v}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_acc_ptr, + void* o_acc_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + ck_tile::index_t num_splits, + float scale_s, + float scale_p, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o_acc, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc, + ck_tile::index_t window_size_left, + ck_tile::index_t 
window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_acc_ptr, + o_acc_ptr, + batch, + max_seqlen_q, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + num_splits, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o_acc, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_lse_acc, + nhead_stride_o_acc, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits) + { + return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits); + } + + __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = + TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v, kargs.num_splits); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + const long_index_t batch_offset_lse_acc = + static_cast(i_batch) * kargs.batch_stride_lse_acc; + const long_index_t batch_offset_o_acc = + static_cast(i_batch) * kargs.batch_stride_o_acc; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + 
{ + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias + key_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + OaccDataType* o_acc_ptr = reinterpret_cast(kargs.o_acc_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o_acc + + batch_offset_o_acc + i_split * kargs.split_stride_o_acc; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + 
number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse acc + auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() { + constexpr auto lse_acc_dram_window_lengths = make_tuple(number{}); + LSEDataType* lse_acc_ptr = + reinterpret_cast(kargs.lse_acc_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse_acc + + batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc; + + const auto lse_acc_dram = [&]() { + const auto lse_acc_dram_naive = + make_naive_tensor_view(lse_acc_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0}); + }(); + + // dropout + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + uint64_t drop_seed = 0; + uint64_t drop_offset = 0; + bool is_store_randval = false; + + if constexpr(kHasDropout) + { + rp_undrop = kargs.rp_undrop; + p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; + drop_seed = kargs.drop_seed; + drop_offset = kargs.drop_offset; + is_store_randval = kargs.is_store_randval; + } + BlockDropout dropout(i_batch, + i_nhead, + kargs.num_head_q, + drop_seed, + drop_offset, + rp_undrop, + p_undrop_in_uint8_t, + is_store_randval); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return 
make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + auto o_acc_tile = [&, i_split_ = i_split]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_acc_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + identity{}, // o_acc_element_func + kargs.num_splits, + i_split_, + mask, + position_encoding, + kargs.scale_s, + smem_ptr, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_acc_dram_window, + kargs.num_splits, + i_split_, + mask, + position_encoding, + kargs.scale_s, + smem_ptr, + dropout); + } + }(); + + // Oacc DRAM and Oacc DRAM window + auto o_acc_dram = [&]() { + const auto o_acc_dram_naive = make_naive_tensor_view( + o_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.hdim_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_acc_dram_window = + make_tile_window(o_acc_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_acc_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp new file mode 100644 index 000000000..aec37cb36 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
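+//
+// Overview of the partitioner defined below: GridSize() packs the split-KV work as
+//   gridDim.x = ceil(seqlen_q / kM0) * ceil(hdim_v / kN1)
+//   gridDim.y = nhead * num_splits
+//   gridDim.z = batch_size
+// and operator() decodes blockIdx back into the (i_tile_m, i_tile_n, i_split,
+// i_nhead, i_batch) tuple consumed by the split-KV kernel's operator().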
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaFwdSplitKVTilePartitioner +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; + static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; + static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) * + ck_tile::integer_divide_ceil(hdim_v, kN1), + nhead * num_splits, + batch_size); + } + + CK_TILE_DEVICE auto + operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v, ck_tile::index_t num_splits) + { + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(blockIdx.x, num_tile_n1); + const auto [i_nhead, i_split] = f(blockIdx.y, num_splits); + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp new file mode 100644 index 000000000..7efdb798c --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { +namespace detail { +template +struct log2; + +template <> +struct log2<16> : std::integral_constant +{ +}; + +template <> +struct log2<32> : std::integral_constant +{ +}; + +template <> +struct log2<64> : std::integral_constant +{ +}; + +template <> +struct log2<128> : std::integral_constant +{ +}; +} // namespace detail + +template +struct BlockFmhaFwdSplitKVCombinePipeline +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using LSEDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kHeadDimV = Problem::kHeadDimV; + static constexpr index_t kM0 = Problem::kM0; + static constexpr index_t kN1 = Problem::kN1; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr index_t kMaxSplits = Problem::kMaxSplits; + + static constexpr index_t kAlignmentLSE = + kPadSeqLenQ ? 1 : Policy::template GetAlignmentLSE(); + static constexpr index_t kAlignmentLSEacc = kAlignmentLSE; + + static constexpr index_t kAlignmentOacc = + kPadHeadDimV ? 
1 : Policy::template GetAlignmentOacc(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kHeadDimV <= 32) + { + constexpr std::array occupancy{3, 3, 3, 1}; + return occupancy[detail::log2::value - 4]; + } + else if constexpr(kHeadDimV <= 128) + { + constexpr std::array occupancy{3, 3, 2, 1}; + return occupancy[detail::log2::value - 4]; + } + else if constexpr(kHeadDimV <= 256) + { + constexpr std::array occupancy{2, 2, 2, 1}; + return occupancy[detail::log2::value - 4]; + } + } + }(); + + static constexpr const char* name = "unused"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, + const OaccDramBlockWindowTmp& o_acc_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, + const LSEElementFunction& lse_element_func, + const OaccElementFunction& o_acc_element_func, + index_t num_splits, + index_t max_seqlen_q, + void* smem_ptr) const + { + // lse_acc tile in LDS + LSEDataType* lse_acc_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + auto lse_acc_lds = [=, lds_desc = Policy::template MakeLSEaccLdsBlockDescriptor()]( + index_t row, index_t col) -> LSEDataType& { + return lse_acc_lds_ptr[lds_desc.calculate_offset(make_tuple(row, col))]; + }; + + auto lse_acc_lds_write_window = [&]() { + auto view = make_tensor_view( + lse_acc_lds_ptr, Policy::template MakeLSEaccLdsStoreBlockDescriptor()); + return make_tile_window(view, make_tuple(number{}, number{}), {0, 0}); + }(); + + auto lse_acc_dram_window = + make_tile_window(lse_acc_dram_block_window_tmp.get_bottom_tensor_view(), + lse_acc_dram_block_window_tmp.get_window_lengths(), + lse_acc_dram_block_window_tmp.get_window_origin(), + Policy::template MakeLSEaccDramTileDistribution()); + + // copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]). + auto lse_acc_tile = load_tile(lse_acc_dram_window); + store_tile(lse_acc_lds_write_window, lse_acc_tile); + block_sync_lds(); + + auto lse_accum = make_static_distributed_tensor( + Policy::template MakeLSEaccRegTileDistribution()); + + // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)]) + // this will extend the distributed tensor width so that each thread in wave have data to + // reduce. + { + constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + lse_accum.get_tile_distribution(), i_j_idx); + + const auto col = x_indices.at(number<1>{}); + if(col < num_splits) + { + const auto row = x_indices.at(number<0>{}); + + lse_accum(i_j_idx) = lse_acc_lds(row, col); + } + else + { + lse_accum(i_j_idx) = -numeric::infinity(); + } + }); + }); + } + + // compute the logsumexp of the LSE along the split dimension. 
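+        // In formula form: given the per-split values lse_i,
+        //   m   = max_i(lse_i)
+        //   lse = log(sum_i exp(lse_i - m)) + m
+        // and each split's partial output is weighted by exp(lse_i - lse). Those
+        // weights are written back into LDS further below, before the O accumulation.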
+ const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + auto lse_max = block_tile_reduce( + lse_accum, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(lse_max, f_max, bool_constant{}); + + static const auto get_validated_m = [](LSEDataType raw_m) { + return raw_m == -numeric::infinity() ? type_convert(0.f) + : raw_m; + }; + + decltype(lse_accum) lse_exp; + { + constexpr auto spans = decltype(lse_exp)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + lse_exp(i_j_idx) = + ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx))); + }); + }); + } + + auto lse_sum = block_tile_reduce( + lse_exp, sequence<1>{}, f_sum, type_convert(0)); + block_tile_reduce_sync(lse_sum, f_sum, bool_constant{}); + + decltype(lse_max) lse_logsum; + { + constexpr auto spans = decltype(lse_logsum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx)) + { + lse_logsum(i_idx) = numeric::infinity(); + } + else + { + lse_logsum(i_idx) = + ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx)); + } + }); + } + + // store the lse scales in shared memory. + { + constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + const auto x_indices = get_x_indices_from_distributed_indices( + lse_accum.get_tile_distribution(), i_j_idx); + + const auto col = x_indices.at(number<1>{}); + if(col < num_splits) + { + const auto row = x_indices.at(number<0>{}); + + lse_acc_lds(row, col) = + ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx)); + } + }); + }); + } + block_sync_lds(); + + if constexpr(kStoreLSE) + { + constexpr auto spans = decltype(lse_logsum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(lse_logsum(i_idx) == numeric::infinity()) + { + lse_logsum(i_idx) = -numeric::infinity(); + } + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); + } + + auto o_acc_dist = Policy::template MakeOaccDramTileDistribution(); + auto o_acc_dram_window = + make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(), + o_acc_dram_block_window_tmp.get_window_lengths(), + o_acc_dram_block_window_tmp.get_window_origin(), + o_acc_dist); + auto o_acc = make_static_distributed_tensor(o_acc_dist); + clear_tile(o_acc); + + const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0; + + for(index_t i_split = 0; i_split < num_splits; ++i_split) + { + auto o_tile = load_tile(o_acc_dram_window); + { + constexpr auto spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + o_acc.get_tile_distribution(), i_j_idx); + + const auto row = x_indices.at(number<0>{}); + + const LSEDataType lse_scale = lse_acc_lds(row, i_split); + 
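+                        // lse_scale = exp(lse_i - lse_combined) was precomputed into LDS
+                        // above; the weighted sum below merges the per-split partial
+                        // outputs into the full-sequence attention output.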
o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx); + }); + }); + } + + move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0}); + } + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window, + const OaccDramBlockWindow& o_acc_dram_block_window, + LSEDramBlockWindow& lse_dram_block_window, + index_t num_splits, + index_t max_seqlen_q, + void* smem_ptr) const + { + return operator()(lse_acc_dram_block_window, + o_acc_dram_block_window, + lse_dram_block_window, + identity{}, + identity{}, + num_splits, + max_seqlen_q, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp new file mode 100644 index 000000000..2eb092f05 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE() + { + using LSEDataType = remove_cvref_t; + return 16 / sizeof(LSEDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() + { + using OaccDataType = remove_cvref_t; + return 16 / sizeof(OaccDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using ODataType = remove_cvref_t; + return 16 / sizeof(ODataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return sizeof(typename Problem::LSEDataType) * + MakeLSEaccLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution() + { + using LSEDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t kMPerBlock = Problem::kMaxSplits; + + constexpr index_t NPerThread = 16 / sizeof(LSEDataType); + constexpr index_t NThreads = kNPerBlock / NPerThread; + + constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads; + constexpr index_t TotalWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp); + + static_assert(NThreads * NPerThread == kNPerBlock); + static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + // 3d + padding, [kMaxSplits, kM0] + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor() + { + using LSEDataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::kMaxSplits; + constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t NPack = 16 / sizeof(LSEDataType); + + constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * NPack>{}, number{}, number<1>{}), + number<8>{}, + 
number<1>{}); + + constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor( + lse_acc_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lse_acc_lds_block_desc; + } + + // 3d + padding, [kM0, kMaxSplits] + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor() + { + using LSEDataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::kMaxSplits; + constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t NPack = 16 / sizeof(LSEDataType); + + constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * NPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor( + lse_acc_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return lse_acc_t_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size()); + constexpr index_t kMPerBlock = Problem::kM0; + + constexpr index_t NThreads = get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / NThreads; + + constexpr index_t MThreads = kBlockSize / NThreads; + constexpr index_t MPerThread = kMPerBlock / MThreads; + + static_assert(NThreads * NPerThread == kNPerBlock); + static_assert(MThreads * MPerThread == kMPerBlock); + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution() + { + using OaccDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kNPerBlock = Problem::kN1; + + constexpr index_t N1 = 16 / sizeof(OaccDataType); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t M2 = get_warp_size() / N0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp new file mode 100644 index 000000000..a6d74b388 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -0,0 +1,666 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
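+//
+// Split-KV forward pipeline (qr/ks/vs flavor): each workgroup processes one split of
+// the key/value sequence, i.e. the [seqlen_k_start, seqlen_k_end) range returned by
+// mask.GetTileRangeAlongX() for its i_split, and stores a per-split normalized partial
+// output together with its LSE into the *_acc buffers, which the split-KV combine
+// pipeline merges afterwards.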
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaFwdSplitKVPipelineQRKSVS +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = true; // always store LSE (acc) + static constexpr bool kHasDropout = false; // ignore this flag + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0BlockLength <= 32) + { + return 2; + } + else if constexpr(kK0BlockLength <= 64) + { + return 3; + } + else if constexpr(kK0BlockLength <= 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + const LSEaccElementFunction& lse_acc_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + auto q = load_tile(q_dram_window); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, 
S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking || kHasUnevenSplits) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse_acc, -numeric::infinity()); + + store_tile(lse_acc_dram_window_tmp, + tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? 
+ Policy::template MakeVDramTileDistribution()); + + auto q_tile = tile_elementwise_in(q_element_func, q); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + do + { + // STAGE 1, QK gemm + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(k_dram_window, {0, kK0}); + clear_tile(s_acc); // initialize C + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_window); + } + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); 
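+                // When CK_TILE_FMHA_FWD_FAST_EXP2 is enabled, scale_s is not applied
+                // here: MakeKargs() already folded log2e into scale_s and the softmax
+                // below computes exp2(scale_s * s - row_max) directly.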
+#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + + /// TODO: only check in last iteration without increasing code size + if constexpr(kHasUnevenSplits) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + set_tile_if(s_acc, + -numeric::infinity(), + [&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) { + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return seqlen_k_end_ <= col; + }); + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? 
type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? 
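+                    // Note: the 1/l normalization is deferred to the epilogue after the
+                    // main loop, so inside the loop only the exp(m_old - m_new) rescale
+                    // is applied to the running output accumulator.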
+ o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + dropout.Run( + smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); + } + + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + if constexpr(kStoreLSE) + { + // store lse acc + auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } +#else + lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 
0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_acc_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + num_splits, + i_split, + mask, + position_encoding, + scale_s, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp new file mode 100644 index 000000000..ae363a497 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp @@ -0,0 +1,770 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
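+//
+// Async variant of the split-KV qr/ks/vs pipeline: K tiles are loaded into LDS with
+// async copies and rotated through the LDS buffers selected by LdsSeq, so the next
+// K tile load can overlap the QK gemm on the previously staged buffer.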
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = true; // always store LSE (acc) + static constexpr bool kHasDropout = false; // ignore this flag + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); + +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / log2e_v; +#endif + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0BlockLength <= 32) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && + FmhaMask::IsMasking) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 64) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 2; + else + return 3; + } + else if constexpr(kK0BlockLength <= 128) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr_async"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& /*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + const LSEaccElementFunction& lse_acc_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + auto k_lds_load = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor(i_buf)), + Policy::template MakeKLdsLoadBlockDescriptor(i_buf).get_lengths(), + {0, 0}); + }, + number{}); +#else + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template 
MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); +#endif + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse_acc, -numeric::infinity()); + + store_tile(lse_acc_dram_window_tmp, + tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) 
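+                // a zeroed o_acc together with an lse_acc of -inf means this split carries
+                // zero weight (exp(-inf) == 0) when the split-kv combine step merges the
+                // partial results, so a fully masked-out split contributes nothing.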
+ + // Note: here occ are all cleard, return it + return o_acc; + } + __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer()); + (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 + // auto q_tile = q; // tile_elementwise_in(q_element_func, q); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[number{})>{}]); + +#else + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); +#endif + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto v_buf = load_tile(v_dram_window, bool_constant{}); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[number{})>{}]); + +#else + get_slice_tile( + k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); +#endif + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = 
tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + + /// TODO: only check in last iteration without increasing code size + if constexpr(kHasUnevenSplits) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + set_tile_if(s_acc, + -numeric::infinity(), + [&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) { + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return seqlen_k_end_ <= col; + }); + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + 
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? 
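+                // The equation is fine: o_acc only accumulates the unnormalized numerator.
+                // Rescaling the old accumulator by exp(m_old - m) here and dividing by the
+                // final row sum l once after the loop is algebraically the same as carrying
+                // the l-ratio rescale every iteration, so the deferred normalization still
+                // produces the correct output.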
+ o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf + } + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window = + make_tile_window(k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < num_total_loop); + + // store lse acc + if constexpr(kStoreLSE) + { + auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse_acc(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + } +#else + lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if 
constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_acc_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + num_splits, + i_split, + mask, + position_encoding, + scale_s, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp new file mode 100644 index 000000000..6109fa5ab --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp new file mode 100644 index 000000000..338319ab3 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 1b72b6005..23b75f16a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -54,4 +54,69 @@ struct BlockFmhaPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; +template +struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem +{ + static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; +}; + +template +struct BlockFmhaSplitKVCombinePipelineProblem +{ + using LSEDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kBlockSize = 256; + static constexpr bool kIsGroupMode = kIsGroupMode_; + + static constexpr index_t kHeadDimV = HeadDimV_; + static constexpr index_t kM0 = kM0_; + static constexpr index_t kN1 = kN1_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr index_t kMaxSplits = Traits::kMaxSplits; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 973ffa9f8..a59431e39 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -32,6 +32,50 @@ struct TileFmhaTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; +template +struct TileFmhaFwdSplitKVTraits : TileFmhaTraits +{ + // determine if some split (length) is not divisible by tile size + static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; +}; + +template +struct TileFmhaFwdSplitKVCombineTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + + static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_); + static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0); + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + template -- GitLab From a32b1bc64794b8bd51f6924cb7d72dd1059803dd Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 26 Jun 2024 22:04:52 +0800 Subject: [PATCH 67/96] Replace hipDeviceSynchronize() by hipStreamSynchronize(stream) calls (#1359) --- include/ck_tile/host/timer.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/host/timer.hpp b/include/ck_tile/host/timer.hpp index e2baeaef7..e5519643b 100644 --- a/include/ck_tile/host/timer.hpp +++ b/include/ck_tile/host/timer.hpp @@ -27,7 +27,7 @@ struct gpu_timer CK_TILE_HOST void start(const hipStream_t& s) { - HIP_CHECK_ERROR(hipDeviceSynchronize()); + 
HIP_CHECK_ERROR(hipStreamSynchronize(s)); HIP_CHECK_ERROR(hipEventRecord(start_evt, s)); } @@ -51,15 +51,15 @@ struct gpu_timer struct cpu_timer { // torch.utils.benchmark.Timer(), there is a sync inside each timer callback - CK_TILE_HOST void start(const hipStream_t&) + CK_TILE_HOST void start(const hipStream_t& s) { - HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipStreamSynchronize(s)); start_tick = std::chrono::high_resolution_clock::now(); } // torch.utils.benchmark.Timer(), there is a sync inside each timer callback - CK_TILE_HOST void stop(const hipStream_t&) + CK_TILE_HOST void stop(const hipStream_t& s) { - HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipStreamSynchronize(s)); stop_tick = std::chrono::high_resolution_clock::now(); } // return in ms -- GitLab From 941d1f7ce0a7529e7e23039baff1da3064a5d10e Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 27 Jun 2024 00:33:34 -0700 Subject: [PATCH 68/96] Merging the gfx12 code into public repo. (#1362) --- CMakeLists.txt | 6 +- Jenkinsfile | 4 +- cmake/EnableCompilerWarnings.cmake | 2 +- example/01_gemm/gemm_wmma_fp16.cpp | 54 +- example/01_gemm/run_gemm_example.inc | 2 +- .../04_gemm_add_add_fastgelu/CMakeLists.txt | 2 +- .../batched_gemm_bias_e_permute_wmma_fp16.cpp | 4 +- .../cross_attention_forward_wmma_fp16.cpp | 6 +- .../self_attention_forward_wmma_fp16.cpp | 6 +- example/CMakeLists.txt | 4 +- include/ck/ck.hpp | 14 +- include/ck/host_utility/device_prop.hpp | 5 + .../gpu/block/blockwise_gemm_wmma.hpp | 499 ++++++++++++++++++ .../gpu/block/blockwise_gemm_xdlops.hpp | 7 + ...d_contraction_multiple_d_wmma_cshuffle.hpp | 23 +- .../device_batched_gemm_multiple_d_dl.hpp | 7 +- ...emm_softmax_gemm_permute_wmma_cshuffle.hpp | 9 +- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 4 +- .../device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp | 2 +- .../device/impl/device_fpAintB_gemm_wmma.hpp | 2 +- .../gpu/device/impl/device_gemm_dl.hpp | 2 +- .../device/impl/device_gemm_multiple_d_dl.hpp | 7 +- .../device_gemm_multiple_d_wmma_cshuffle.hpp | 2 +- .../gpu/device/impl/device_gemm_wmma.hpp | 23 +- ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 2 +- .../device_grouped_conv_bwd_weight_dl.hpp | 5 +- ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 2 +- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 7 +- ...ice_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp | 4 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 2 +- .../device_grouped_gemm_multiple_d_dl.hpp | 7 +- ...e_grouped_query_attention_forward_wmma.hpp | 5 +- ...ice_multi_query_attention_forward_wmma.hpp | 5 +- ...atched_gemm_softmax_gemm_wmma_cshuffle.hpp | 20 +- .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 22 +- ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 44 +- .../gpu/grid/gridwise_gemm_wmma.hpp | 42 +- .../gpu/grid/gridwise_tensor_rearrange.hpp | 5 +- .../threadwise_tensor_slice_transfer.hpp | 109 +++- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 147 +++++- include/ck/utility/amd_wmma.hpp | 82 +++ include/ck/utility/data_type.hpp | 2 +- include/ck/utility/synchronization.hpp | 17 + include/ck_tile/core/config.hpp | 5 +- .../gpu/CMakeLists.txt | 8 +- profiler/src/CMakeLists.txt | 4 +- test/CMakeLists.txt | 4 +- .../test_grouped_convnd_bwd_weight.cpp | 2 +- test/wmma_op/wmma_op_util.hpp | 16 + 49 files changed, 1100 insertions(+), 164 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e8626b2cb..b27e6ab4f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -117,7 +117,7 @@ else() 
add_definitions(-DPROFILER_ONLY) set(GPU_TARGETS "" CACHE STRING "" FORCE) if(GPU_TARGETS) - message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12") endif() if(GPU_ARCH MATCHES "gfx90") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") @@ -127,8 +127,10 @@ else() rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") elseif(GPU_ARCH MATCHES "gfx11") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") + elseif(GPU_ARCH MATCHES "gfx12") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201") else() - message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12") endif() set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) endif() diff --git a/Jenkinsfile b/Jenkinsfile index 855fe8dff..67e9b2fcb 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){ def variant = env.STAGE_NAME def retimage + gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) @@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM pipeline { agent none - triggers { - parameterizedCron(CRON_SETTINGS) - } options { parallelsAlwaysFailFast() } diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index fb2b38d68..93fd306e9 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,7 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror + -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 8c52e4f7d..f8afe8d6d 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle - < ALayout, - BLayout, - CLayout, - ADataType, + < ALayout, + BLayout, + CLayout, + ADataType, BDataType, - CDataType, - AccDataType, - CShuffleDataType, - AElementOp, - BElementOp, - CElementOp, - GemmDefault, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, 1, // Prefetch stage 128, // BlockSize 64, // MPerBlock 128, // NPerBlock 64, // KPerBlock - 8, // K1 + 2, // K1 16, // MPerWmma 16, // NPerWmma 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, 1, // C shuffle (M Repeat) Per store 1, // C shuffle (N Repeat) Per store - S<1, 32, 1, 4>, + S<1, 32, 1, 4>, 8>; // clang-format on diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index b04e4e53a..cb15186c3 100644 --- a/example/01_gemm/run_gemm_example.inc +++ 
b/example/01_gemm/run_gemm_example.inc @@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; case 4: - ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); break; case 5: diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index ab19f819e..be47665a2 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32) set(target 1) endif() -endforeach() \ No newline at end of file +endforeach() diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 2bbf430c4..f556be887 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN = 2, 4, 4, - true, + false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, - true, + false, 1, 1, S<1, 64, 1, 2>, diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp index 4c92c5497..fac19f8b5 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp @@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial #define CK_MHA_USE_WAVE_1 #define CK_MHA_USE_WAVE_2 #define CK_MHA_USE_WAVE_4 -#define CK_MHA_USE_WAVE_8 +//#define CK_MHA_USE_WAVE_8 using DeviceMHAFactory = std::tuple< #ifdef CK_MHA_USE_WAVE_1 @@ -277,10 +277,10 @@ using DeviceMHAFactory = S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, // CShuffleBlockTransfer MN 1, 1, S<1, 64, 1, 2>, 8, - MaskingSpec>, + MaskingSpec> #endif #ifdef CK_MHA_USE_WAVE_8 - ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp index 8e037272b..d463cc871 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp @@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial #define CK_MHA_USE_WAVE_1 #define CK_MHA_USE_WAVE_2 #define CK_MHA_USE_WAVE_4 -#define CK_MHA_USE_WAVE_8 +//#define CK_MHA_USE_WAVE_8 using DeviceMHAFactory = std::tuple< #ifdef CK_MHA_USE_WAVE_1 @@ -277,10 +277,10 @@ using DeviceMHAFactory = S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, // 
CShuffleBlockTransfer MN 1, 1, S<1, 64, 1, 2>, 8, - MaskingSpec>, + MaskingSpec> #endif #ifdef CK_MHA_USE_WAVE_8 - ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index fd9f5cd89..c9781637d 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() @@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 32eea551f..9528a30b4 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) #define __gfx11__ #endif +#if defined(__gfx1200__) || defined(__gfx1201__) +#define __gfx12__ +#endif // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code @@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx103__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#elif defined(__gfx11__) +#elif defined(__gfx11__) || defined(__gfx12__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif @@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 -#elif defined(__gfx11__) +#elif defined(__gfx11__) || defined(__gfx12__) #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8_GFX11 @@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define CK_USE_AMD_MFMA_GFX940 #endif -// WMMA instruction -#ifndef __HIP_DEVICE_COMPILE__ // for host code -#define CK_USE_AMD_WMMA -#elif defined(__gfx11__) // for GPU code -#define CK_USE_AMD_WMMA -#endif - // buffer load #define CK_USE_AMD_BUFFER_LOAD 1 diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 116bb3ea0..83af2efe8 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -84,4 +84,9 @@ inline bool is_gfx11_supported() ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; } +inline bool is_gfx12_supported() +{ + return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; +} + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 873539f8b..3ea19da74 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -13,6 +13,504 @@ namespace ck { +#ifdef __gfx12__ +template +/* Option: Read from LDS, big buffer hold all threads required data + * Source + * A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * Destination + * C, non-transpose + * thread level: MRepeat x NRepeat x MAccVgprs + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * KPACK == WMMA_K = 16 + * + * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) + * Source: + * A(if skip LDS): MRepeat x KPack + * B(if skip LDS): NRepeat x KPack + * Destination + * C, non-transpose + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + */ +struct BlockwiseGemmWMMA +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto WmmaK = Number<16>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. + static constexpr index_t WaveSize = 32; + + // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer + // When not use LDS, each Row read half of whole data from source buffer, exchange the data via + // permutation + static constexpr index_t A_KRow = 2; + static constexpr index_t B_KRow = 2; + + static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + // Default, Block buffer in LDS, thread level offset enabled + __device__ static auto CalculateAThreadOriginDataIndex() + { + if constexpr(AEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + if constexpr(BEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 
0); + } + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); + + return make_tuple( + Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && + NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + + // Thread level, register decriptor. 
Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Provide dimension size + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Describe how data allocated in thread copy src buffer + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_assert(KPack % (A_K1 * A_KRow) == 0, ""); + static_assert(KPack % (B_K1 * B_KRow) == 0, ""); + + // 
basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / KPack, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + template + struct AThreadCopySelector; + + template <> + struct AThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + }; + + template <> + struct AThreadCopySelector + { + using type = 
ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatA, + FloatA, + decltype(a_block_desc_k0_m0_m1_m2_k1), + decltype(a_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + false>; + }; + + template + struct BThreadCopySelector; + + template <> + struct BThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct BThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatB, + FloatB, + decltype(b_block_desc_k0_n0_n1_n2_k1), + decltype(b_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + false>; + }; + + typename AThreadCopySelector::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; +}; +#else template ::type a_thread_copy_; typename BThreadCopySelector::type b_thread_copy_; }; +#endif } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index e2296a55f..d3f6344c2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -487,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // sync point. if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) { +#ifdef __gfx12__ + asm volatile("\ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("s_barrier" ::); +#endif __builtin_amdgcn_sched_barrier(0); } static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index a15759559..ab3f3856a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto WmmaK = K1 == 16 ? 32 : 16; - static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; - static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? 
false : true; // If true, LDS is used unconditionally static constexpr auto AEnableLds_manu = false; @@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { @@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle } else { - if(!(arg.a_kz_stride_ == 1 && - arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + if(!(arg.a_kz_stride_ == 1)) { - printf("DeviceOp: Vector Access A-k check failure\n"); - return false; + index_t LastK = + AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6); + if(LastK % ABlockTransferSrcScalarPerVector == 0) + { + printf("DeviceOp: Vector Access A-k check failure\n"); + return false; + } } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp index 8fd14afc0..1b487502f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp @@ -70,8 +70,9 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ + defined(__gfx12__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index f1bc6a226..f0f89f1d1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -592,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle return false; } - if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" && - ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" && - std::is_same::value) + if(!ck::is_lds_direct_load_supported() && std::is_same::value) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index b84e18130..1edae33be 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl { // check device if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || - ck::is_gfx11_supported())) + ck::is_gfx11_supported() || ck::is_gfx12_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp index bf96324d0..553143e28 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index b1784b385..eb0fb55f5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index 93ab8a7e1..a7cc546f5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm{}; - static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); - static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); - static constexpr auto WmmaK = K1 == 16 ? 32 : 16; - - static constexpr auto AEnableLds_auto = - (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; + + static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) && + is_same::value) + ? false + : true; static constexpr auto BEnableLds_auto = - (MWaves == 1 && is_same::value) ? false : true; + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) && + is_same::value) + ? 
false + : true; // If true, LDS is used unconditionally static constexpr auto AEnableLds_manu = false; @@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 6f74838fb..6bb5d431c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { // check device - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 86091aeba..cc26936fe 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -48,8 +48,9 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index 211185dfb..5738be0fb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { // check device - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index ce86ec54e..c3fe54b07 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -90,8 +90,9 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) 
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -667,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK // check device if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || - ck::is_gfx103_supported() || ck::is_gfx11_supported())) + ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index 5c9d63e2b..c6b84b613 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -107,7 +107,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx11__)) + defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -603,7 +603,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index ac392cddc..060a16d1e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -39,8 +39,9 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ + defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp index 4e14ed3a5..cc88c1a10 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp @@ -60,7 +60,7 @@ __global__ void bool input_permute, bool output_permute) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** @@ -165,6 +165,7 @@ __global__ void ignore = O; ignore = G0; ignore = G1; + ignore = alpha; ignore = input_permute; ignore = 
output_permute; #endif // end of if (defined(__gfx11__)) @@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma static bool IsSupportedArgument(const RawArg& arg) { - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp index 16717ff81..1754e07e6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma if constexpr(B0EnableLds) { // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 - constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( B0BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma if constexpr(B1EnableLds) { // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 - constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); - constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); + constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_LRow = I2; +#else constexpr auto B_LRow = I1; +#endif return transform_tensor_descriptor( B1BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 499eb7eb0..21dac6f9e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -50,7 +50,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, @@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif return transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), 
make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 49a6dc3b0..b3b057c80 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -54,7 +54,7 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -147,7 +147,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // printf("entry kernel launch"); __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; @@ -237,7 +237,7 @@ __global__ void const CDEElementwiseOperation cde_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; GridwiseOp::template Run(p_a_grid, @@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma } else { + constexpr auto A_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma } else { + constexpr auto B_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif return 
transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma // *Caution Here repeat is shuffle repeat GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; } @@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma const auto M = e_grid_desc_m_n.GetLength(I0); const auto N = e_grid_desc_m_n.GetLength(I1); - const auto MBlock = M / MPerBlock; - const auto NBlock = N / NPerBlock; + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( e_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 8e4117593..4458b9356 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -45,7 +45,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, @@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma } else { + constexpr auto A_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma } else { + + constexpr auto B_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - 
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif + return transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; - using DefaultBlock2CTileMap = - remove_cvref_t; - struct SharedMemTrait { // LDS allocation for A and B: be careful of alignment @@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma b_block_space_size_aligned * sizeof(BDataType)); }; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + template __device__ static void Run(const ADataType* __restrict__ p_a_grid, const BDataType* __restrict__ p_b_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp index 6772524e0..174074990 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp @@ -35,8 +35,9 @@ __global__ void const Block2ETileMap block_2_tile_map, const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ + defined(__gfx12__)) GridwiseTensorRearrangeKernel::Run(in_grid_desc, p_in_global, out_grid_desc, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index bcce930fc..d7a6a3624 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ElementwiseOperation element_op_; }; -// Specilized for WMMA +// Specilized for WMMA-Navi3 // A single Wave32 is composed by double row // Data exchange allowed between these two rows // This RowLane Dst buf will be filled from two Src buf @@ -1439,4 
+1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ElementwiseOperation element_op_{}; }; +// Specilized for WMMA-Navi4 +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + ignore = src_idx; + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v_this_row; + // int type temp value due to intrinsic requirement + int temp = 0; + + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); + + // apply intra-row permute. 
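                    // Assuming the usual v_permlane16 encoding (one 4-bit source-lane selector
                    // per lane, sel_lo for lanes 0-7 and sel_hi for lanes 8-15, with the trailing
                    // 1, 0 arguments being the fi / bound_ctrl flags), the selector constants
                    // 0xb3a29180 / 0xf7e6d5c4 appear to gather values in the source-lane order
                    // 0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15, i.e. they interleave the two 8-lane
                    // halves of each 16-lane row.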
+ if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } + + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); + }); + }); + } + ElementwiseOperation element_op_{}; +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 565195f53..9a9ebf559 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -11,12 +11,17 @@ namespace ck { enum struct WmmaInstr { + // gfx11 wmma_f32_16x16x16_f16 = 0, wmma_f32_16x16x16_bf16, wmma_f16_16x16x16_f16, wmma_bf16_16x16x16_bf16, wmma_i32_16x16x16_iu8, - wmma_i32_16x16x16_iu4 + wmma_i32_16x16x16_iu4, + // gfx12 + wmma_f32_16x16x16_f16_gfx12, + wmma_f32_16x16x16_bf16_gfx12, + wmma_i32_16x16x16_iu8_gfx12, }; /* @@ -279,6 +284,122 @@ struct wmma_type +struct wmma_type> +{ + // Absolute fixing property + // * Data Pixel + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + // static constexpr index_t acc_data_size = 4; + // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // * Fixed in Navi3x, Will be wave mode dependent on Navi4x + // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4; + // * num_acc_vgprs_per_wave alone M direction + // * num_subgroups alone M direction + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_f16_w32_gfx12::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 
32) + { + intrin_wmma_f32_16x16x16_bf16_w32_gfx12::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_i32_16x16x16_iu8_w32_gfx12::Run( + a, b, reg_c); + } + } +}; + template static constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; +#else return WmmaInstr::wmma_f32_16x16x16_f16; +#endif } template <> static constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; +#else return WmmaInstr::wmma_f32_16x16x16_bf16; +#endif } template <> @@ -320,8 +449,13 @@ struct WmmaSelector template <> static constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; +#else return WmmaInstr::wmma_i32_16x16x16_iu8; +#endif } + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> static constexpr auto GetWmma() @@ -502,6 +636,9 @@ struct WmmaGemm __device__ static auto GetSubGroupId() { + static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups == + wmma_instr.wave_size, + ""); return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups; } @@ -516,12 +653,20 @@ struct WmmaGemm __host__ __device__ static auto CalculateAThreadOriginDataIndex() { +#ifdef __gfx12__ + return GetLaneIdUnderSubGroup(); +#else return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); +#endif } __host__ __device__ static auto CalculateBThreadOriginDataIndex() { +#ifdef __gfx12__ + return GetLaneIdUnderSubGroup(); +#else return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); +#endif } __device__ static CIndex GetBeginOfThreadBlk() diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index 1bb0140f3..322a0f94b 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> } }; +// gfx12 +/********************************WAVE32 MODE***********************************************/ + +#if defined(__gfx1200__) || defined(__gfx1201__) +#define __gfx12__ +#endif + +// src: fp16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f16_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { + // * Inline assembly need to elimate the duplicated data load, compiler won't help you + // delete them. 
+ // amd_assembly_wmma_f32_16x16x16_f16_w32( + // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: iu8, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12; + +template +struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + } // namespace ck #endif diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 93a1edefb..4df14c621 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -203,7 +203,7 @@ struct vector_type } }; -int static err = 0; +__device__ int static err = 0; template struct vector_type { diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp index 4fe5e3950..d6b6eac26 100644 --- a/include/ck/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -10,12 +10,20 @@ namespace ck { __device__ void block_sync_lds() { #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM +#ifdef __gfx12__ + asm volatile("\ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else // asm volatile("\ // s_waitcnt lgkmcnt(0) \n \ // s_barrier \ // " ::); __builtin_amdgcn_s_waitcnt(0xc07f); __builtin_amdgcn_s_barrier(); +#endif #else __syncthreads(); #endif @@ -23,11 +31,20 @@ __device__ void block_sync_lds() __device__ void block_sync_lds_direct_load() { +#ifdef __gfx12__ + asm volatile("\ + s_wait_vmcnt 0x0 \n \ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("\ s_waitcnt vmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif } __device__ void s_nop() diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 344343d93..83637e18e 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -17,6 +17,9 @@ #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) #define __gfx11__ #endif +#if defined(__gfx1200__) || defined(__gfx1201__) +#define __gfx12__ +#endif #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" @@ -155,7 +158,7 @@ #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx103__) // for GPU code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#elif defined(__gfx11__) // for GPU code +#elif defined(__gfx11__) || defined(__gfx12__) // for GPU 
code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 05b8c035c..1bcc0f802 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -59,7 +59,7 @@ function(add_instance_library INSTANCE_NAME) endforeach() # Do not build WMMA instances if gfx11 targets are not on the target list foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -177,7 +177,7 @@ FOREACH(subdir_path ${dir_list}) message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11")) + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12")) message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") set(add_inst 0) endif() @@ -185,11 +185,11 @@ FOREACH(subdir_path ${dir_list}) message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9")) + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9")) message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) + if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) message("Found xdl, dl, and wmma instances, but none of those meet the target list. 
Skipping.") set(add_inst 0) endif() diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index fa0eb6f88..5262ca33a 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -59,7 +59,7 @@ if(GPU_TARGETS MATCHES "gfx9") endif() -if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) endif() @@ -134,7 +134,7 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) endif() -if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") +if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 49b67992b..66b4d3d27 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -60,7 +60,7 @@ function(add_test_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -139,7 +139,7 @@ function(add_gtest_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 5ef073066..aee80cb2c 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -44,7 +44,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } } - if(ck::is_gfx11_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { // on gfx11 only support for 3d is implemented if constexpr(NDimSpatial{} != 3) diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index 49782bce6..d9ec94771 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -140,10 +140,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele]; } +#ifdef __gfx12__ + asm volatile("\ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("\ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif for(int ele = 0; ele < 16; ++ele) { @@ -155,10 +163,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8]; } +#ifdef __gfx12__ + asm volatile("\ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("\ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif // sync threads, similar to mma_sync // __syncthreads(); -- GitLab From ed21948bcd8dd84a7d7cf625c9d8f472ea8f322f Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 27 Jun 2024 11:30:32 +0200 Subject: [PATCH 
69/96] Add structural sparsity gemm instruction tests (#1309) * first version of smfmac test * add reviewer comments * add reviewer suggestions --- include/ck/utility/amd_smfmac.hpp | 69 ++++ library/include/ck/library/utility/fill.hpp | 37 +- test/CMakeLists.txt | 3 + test/smfmac_op/CMakeLists.txt | 2 + test/smfmac_op/smfmac_op.cpp | 82 +++++ test/smfmac_op/smfmac_op_util.hpp | 361 ++++++++++++++++++++ test/smfmac_op/smfmac_op_xdl.cpp | 89 +++++ 7 files changed, 642 insertions(+), 1 deletion(-) create mode 100644 include/ck/utility/amd_smfmac.hpp create mode 100644 test/smfmac_op/CMakeLists.txt create mode 100644 test/smfmac_op/smfmac_op.cpp create mode 100644 test/smfmac_op/smfmac_op_util.hpp create mode 100644 test/smfmac_op/smfmac_op_xdl.cpp diff --git a/include/ck/utility/amd_smfmac.hpp b/include/ck/utility/amd_smfmac.hpp new file mode 100644 index 000000000..234293085 --- /dev/null +++ b/include/ck/utility/amd_smfmac.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#pragma once + +namespace ck { + +template +struct intrin_smfmac_f32_16x16x32f16; + +template <> +struct intrin_smfmac_f32_16x16x32f16<16, 16> +{ + template + __device__ static void + Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); + } +}; + +template +struct intrin_smfmac_f32_16x16x32bf16; + +template <> +struct intrin_smfmac_f32_16x16x32bf16<16, 16> +{ + template + __device__ static void + Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); + } +}; + +template +struct intrin_smfmac_f32_32x32x16f16; + +template <> +struct intrin_smfmac_f32_32x32x16f16<32, 32> +{ + template + __device__ static void + Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); + } +}; + +template +struct intrin_smfmac_f32_32x32x16bf16; + +template <> +struct intrin_smfmac_f32_32x32x16bf16<32, 32> +{ + template + __device__ static void + Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); + } +}; + +} // namespace ck diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index 4e075df43..333604135 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -133,5 +133,40 @@ struct FillConstant } }; +template +struct TransformIntoStructuralSparsity +{ + // clang-format off + static constexpr T valid_sequences[] = { + 0, 0, 1, 1, + 0, 1, 0, 1, + 0, 1, 1, 0, + 1, 0, 0, 1, + 1, 0, 1, 0, + 1, 1, 0, 0, + }; + // clang-format on + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::for_each(first, last, [=, idx = 0](T& elem) mutable { + auto tmp_idx = idx; + idx += 1; + return elem *= valid_sequences[tmp_idx % (sizeof(valid_sequences) / sizeof(T))]; + }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + } // namespace utils } // namespace ck diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 66b4d3d27..7ee37d211 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -209,4 +209,7 @@ add_subdirectory(wrapper) if(GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() +if(GPU_TARGETS MATCHES "gfx942") + add_subdirectory(smfmac_op) +endif() add_subdirectory(position_embedding) diff --git a/test/smfmac_op/CMakeLists.txt b/test/smfmac_op/CMakeLists.txt new file mode 100644 index 000000000..4ffc423f5 --- /dev/null +++ b/test/smfmac_op/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_smfmac_op smfmac_op_xdl.cpp) +target_link_libraries(test_smfmac_op PRIVATE utility) diff --git a/test/smfmac_op/smfmac_op.cpp b/test/smfmac_op/smfmac_op.cpp new file mode 100644 index 000000000..de4f9414a --- /dev/null +++ b/test/smfmac_op/smfmac_op.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "test/smfmac_op/smfmac_op_util.hpp" + +template +bool run_test() +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + bool pass = true; + + const auto matmul_default = ck::smfmac_op_util::matmul; + + const auto smfmac_kernel_container = std::make_tuple(matmul_default); + + ck::static_for<0, 1, 1>{}([&](auto i) { + pass &= + ck::smfmac_op_util::TestSmfmac{}>( + smfmac_kernel_container)), + Src1Type, + Src2Type, + DstType, + GPUAccType, + CPUAccType, + decltype(Row{}), + decltype(Row{}), + decltype(Row{}), + PassThrough, + PassThrough, + PassThrough, + AccVecSize, + M, + N, + K>{}(std::get{}>(smfmac_kernel_container)); + }); + + return pass; +} +int main(int, char*[]) +{ + bool pass = true; + // clang-format off + // | Src1Type| Src1VecSize| Src2Type| Src2VecSize| DstType| DstVecSize| GPUAccType| CPUAccType| M| N| K| + pass &= run_test< ck::half_t, 4, ck::half_t, 8, float, 4, float, float,16,16,32>(); + pass &= run_test(); + pass &= run_test< ck::half_t, 4, ck::half_t, 8, float, 16, float, float,32,32,16>(); + pass &= run_test(); + // clang-format on + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass; +} diff --git a/test/smfmac_op/smfmac_op_util.hpp b/test/smfmac_op/smfmac_op_util.hpp new file mode 100644 index 000000000..44122c551 --- /dev/null +++ b/test/smfmac_op/smfmac_op_util.hpp @@ -0,0 +1,361 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. 
All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/utility/amd_smfmac.hpp" +#include "ck/library/utility/fill.hpp" + +namespace ck { +namespace smfmac_op_util { + +template +__device__ void +builtin_smfmac_naive_selector(const src_vec1&, const src_vec2&, const int32_t&, acc_vec&) +{ +} + +template <> +__device__ void +builtin_smfmac_naive_selector>( + const half4_t& reg_a, + const half8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_16x16x32f16<16, 16>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void +builtin_smfmac_naive_selector>( + const bhalf4_t& reg_a, + const bhalf8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_16x16x32bf16<16, 16>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void builtin_smfmac_naive_selector< + half4_t, + half8_t, + StaticBufferTupleOfVector>( + const half4_t& reg_a, + const half8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_32x32x16f16<32, 32>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void builtin_smfmac_naive_selector< + bhalf4_t, + bhalf8_t, + StaticBufferTupleOfVector>( + const bhalf4_t& reg_a, + const bhalf8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_32x32x16bf16<32, 32>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +// Smfmac instructions are using 4:2 structural sparsity, that means that in every contignuous +// subgroup of 4 elements, atleast 2 must be equal to zero and the position of non-zero elements is +// stored in idx register to allow selection of corresponding B matrix elements for multiplication. +// Currently smfmac instructions support only A matrix as sparse +template +__global__ void matmul(const src1_t* a, const src2_t* b, dst_t* c) +{ + __shared__ src1_t a_shared[M * K]; + __shared__ src2_t b_shared[K * N]; + const int lane = threadIdx.x; + // smfmac's A part is storing only non-zero elements in 2VGPRs + // smfmac's B part is storing all elements in 4VGPRs + using src1_vec = typename vector_type::type; + using src1_full_vec = typename vector_type::type; + using src2_vec = typename vector_type::type; + src1_vec a_frag = {}; + src2_vec b_frag = {}; + + src1_full_vec a_temp = {}; + src2_vec b_temp = {}; + // initialize c fragment to 0 + using acc_vec = StaticBufferTupleOfVector; + acc_vec c_thread_buf_; + + for(int i = 0; i < 8; ++i) + { + a_temp[i] = a[(lane % M) * K + (lane / M) * 8 + i]; // M K + } + + for(int i = 0; i < 8; ++i) + { + b_temp[i] = b[(8 * (lane / N) + i) * N + (lane % N)]; // K N + } + + __syncthreads(); + + for(int i = 0; i < 8; ++i) + { + a_shared[(lane % M) * K + (lane / M) * 8 + i] = a_temp[i]; + } + for(int i = 0; i < 8; ++i) + { + b_shared[(8 * (lane / N) + i) * N + (lane % N)] = b_temp[i]; + } + + __syncthreads(); + + // Idx must be a 32-bit register and it is storing 4 2-bit indexes of A's non zero elements. 
+ // It starts with last two elements of every 4 elements subgroup set as non-zero + int32_t idx = 0b11101110; + // Bit masks are for zeroing 0-3rd position of idx + static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000}; + + src1_t curr_val; + int32_t a_pos = 0; + for(int j = 0; j < 2; ++j) + { + a_pos = j * 2; + for(int i = 0; i < 4; ++i) + { + curr_val = a_shared[(lane % M) * K + (lane / M) * 8 + 4 * j + i]; + if(curr_val != 0.0f) + { + idx &= ~bit_clear_masks[a_pos]; + idx |= (i % 4) << 2 * a_pos; + a_frag[a_pos] = curr_val; + a_pos++; + } + } + } + + for(int i = 0; i < 8; ++i) + { + b_frag[i] = b_shared[(8 * (lane / N) + i) * N + (lane % N)]; + } + + builtin_smfmac_naive_selector(a_frag, b_frag, idx, c_thread_buf_); + __syncthreads(); + + // store results from unpacked c_thread_buf_ output + if constexpr(K == 32) + { + static_for<0, acc_vec_size, 1>{}([&](auto i) { + c[(4 * (lane / 16) + i) * N + lane % 16] = + ck::type_convert(c_thread_buf_[Number{}]); + }); + } + else + { + static_for<0, acc_vec_size, 1>{}([&](auto i) { + c[((8 * (i / 4)) % 32 + 4 * (lane / 32) + i % 4) * N + lane % 32] = + ck::type_convert(c_thread_buf_[Number{}]); + }); + } +} + +struct GemmParams +{ + GemmParams() : M(16), N(16), K(32), StrideA(32), StrideB(16), StrideC(16), alpha(1), beta(0) {} + + ck::index_t M; + ck::index_t N; + ck::index_t K; + + ck::index_t StrideA; + ck::index_t StrideB; + ck::index_t StrideC; + + float alpha; + float beta; +}; + +template +void RunHostGEMM(const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + auto ref_gemm = GemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); +} + +template +bool RunDeviceGEMM(KernelType kernel, + const Tensor& A, + const Tensor& B, + Tensor& C) +{ + DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize()); + DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(A.mData.data()); + b_n_k_device_buf.ToDevice(B.mData.data()); + kernel<<<1, 64>>>(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer())); + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; +} + +template +struct TestSmfmac +{ + auto PrepareGemmTensor(const ck::smfmac_op_util::GemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_n_k( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + auto f_generate_tensor_value = [](auto& tensor, auto type) { + using dataType = decltype(type); + tensor.GenerateTensorValue(GeneratorTensor_2{-5, 
5}); + }; + + f_generate_tensor_value(a_m_k, ADataType{}); + f_generate_tensor_value(b_n_k, BDataType{}); + ck::utils::TransformIntoStructuralSparsity{}(a_m_k); + + return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result); + } + + auto operator()(const DeviceSmfmac& smfmac_kernel) + { + std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name + << ", CLayout = " << CLayout{}.name << std::endl; + + // Arrange + ck::smfmac_op_util::GemmParams params; + params.M = M; + params.N = N; + params.K = K; + params.StrideA = K; // M K + params.StrideB = N; // K N + params.StrideC = N; // M N + + auto host_tensors = PrepareGemmTensor(params); + + const Tensor& a = std::get<0>(host_tensors); + const Tensor& b = std::get<1>(host_tensors); + Tensor& c_host = std::get<2>(host_tensors); + Tensor& c_device = std::get<3>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + ck::smfmac_op_util::RunHostGEMM( + a, b, c_host, a_element_op, b_element_op, c_element_op); + + // Act + bool is_supported = ck::smfmac_op_util::RunDeviceGEMM(smfmac_kernel, a, b, c_device); + + if(is_supported) + { + // Assert + bool res = false; + if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else + { + std::cout << "UNSUPPORTED CDataType" << std::endl; + } + + return res; + } + else + { + return true; + } + } +}; + +} // namespace smfmac_op_util +} // namespace ck diff --git a/test/smfmac_op/smfmac_op_xdl.cpp b/test/smfmac_op/smfmac_op_xdl.cpp new file mode 100644 index 000000000..292fd259e --- /dev/null +++ b/test/smfmac_op/smfmac_op_xdl.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
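// A minimal host-side sketch, assuming only what the matmul kernel above shows: each group
// of four A values holds at most two non-zeros (4:2 sparsity), and every kept value records
// its position within the group as a 2-bit field of the idx word. The helper names below
// (CompressedGroup, compress_4to2) are illustrative and not part of the library or tests.
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

struct CompressedGroup
{
    std::array<float, 2> vals{};   // the (at most) two non-zero values
    uint32_t idx_bits = 0;         // bits 1:0 = position of vals[0], bits 3:2 = position of vals[1]
};

// Compress one 4:2-sparse group the same way the kernel fills a_frag and idx.
CompressedGroup compress_4to2(const std::array<float, 4>& group)
{
    CompressedGroup out;
    int kept = 0;
    for(int pos = 0; pos < 4 && kept < 2; ++pos)
    {
        if(group[pos] != 0.0f)
        {
            out.vals[kept] = group[pos];
            out.idx_bits |= static_cast<uint32_t>(pos) << (2 * kept);
            ++kept;
        }
    }
    return out;
}

int main()
{
    const std::array<float, 4> group = {0.0f, 1.5f, 0.0f, -2.0f}; // a valid 4:2 pattern
    const CompressedGroup c = compress_4to2(group);
    assert(c.idx_bits == ((3u << 2) | 1u)); // non-zeros came from positions 1 and 3
    std::printf("kept %.1f %.1f, idx bits 0x%x\n", c.vals[0], c.vals[1], c.idx_bits);
    return 0;
}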
+ +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "test/smfmac_op/smfmac_op_util.hpp" + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +class TestSmfmac : public ::testing::Test +{ + protected: + using Src1Type = std::tuple_element_t<0, Tuple>; + static constexpr ck::index_t Src1VecSize = std::tuple_element_t<1, Tuple>{}.value; + using Src2Type = std::tuple_element_t<2, Tuple>; + static constexpr ck::index_t Src2VecSize = std::tuple_element_t<3, Tuple>{}.value; + using DstType = std::tuple_element_t<4, Tuple>; + static constexpr ck::index_t AccVecSize = std::tuple_element_t<5, Tuple>{}.value; + using GPUAccType = std::tuple_element_t<6, Tuple>; + using CPUAccType = std::tuple_element_t<7, Tuple>; + static constexpr ck::index_t M = std::tuple_element_t<8, Tuple>{}.value; + static constexpr ck::index_t N = std::tuple_element_t<9, Tuple>{}.value; + static constexpr ck::index_t K = std::tuple_element_t<10, Tuple>{}.value; + + void Run() + { + bool pass = true; + constexpr auto matmul_default = ck::smfmac_op_util::matmul; + + constexpr auto smfmac_kernel_container = std::make_tuple(matmul_default); + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + pass &= ck::smfmac_op_util::TestSmfmac< + std::tuple_element_t, + Src1Type, + Src2Type, + DstType, + GPUAccType, + CPUAccType, + decltype(Row{}), + decltype(Row{}), + decltype(Row{}), + PassThrough, + PassThrough, + PassThrough, + AccVecSize, + M, + N, + K>{}(std::get{}>(smfmac_kernel_container)); + }); + + EXPECT_TRUE(pass); + } +}; + +template +using I = ck::Number; + +using KernelTypes = + ::testing::Types, F16, I<8>, F32, I<4>, F32, F32, I<16>, I<16>, I<32>>, + std::tuple, BF16, I<8>, F32, I<4>, F32, F32, I<16>, I<16>, I<32>>, + std::tuple, F16, I<8>, F32, I<16>, F32, F32, I<32>, I<32>, I<16>>, + std::tuple, BF16, I<8>, F32, I<16>, F32, F32, I<32>, I<32>, I<16>>>; + +TYPED_TEST_SUITE(TestSmfmac, KernelTypes); +TYPED_TEST(TestSmfmac, TestSmfmacFP16BF16) { this->Run(); } -- GitLab From 3bb0fe6c7e321bb44b1ec247d34e93607f61162c Mon Sep 17 00:00:00 2001 From: alexxu-amd <159800977+alexxu-amd@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:57:58 -0400 Subject: [PATCH 70/96] remove PR trigger for now due to high cost (#1329) --- .azuredevops/rocm-ci.yml | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml index 8c5285675..4161c2d5a 100644 --- a/.azuredevops/rocm-ci.yml +++ b/.azuredevops/rocm-ci.yml @@ -23,20 +23,7 @@ trigger: - Jenkinsfile - LICENSE -pr: - autoCancel: true - branches: - include: - - develop - paths: - exclude: - - .github - - docs - - '.*.y*ml' - - '*.md' - - Jenkinsfile - - LICENSE - drafts: false +pr: none jobs: - template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo -- GitLab From fafa567b3cd00993af70b78114429cb6fd4723a0 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 27 Jun 2024 11:09:00 -0700 Subject: [PATCH 71/96] Adding a private docker for ROCm6.2 release candidate. 
(#1365) * add private docker for rocm6.2_rc1 * update dockerfile --- Dockerfile | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index 0d3807f4d..0c98188b9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,11 +23,11 @@ RUN if [ "$ROCMVERSION" != "6.2" ]; then \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ - elif [ "$ROCMVERSION" = "6.2" ] && [ "$compiler_version" = "rc2" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.1-20.04-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.1-20.04-1_all.deb && \ - sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.1 rel-48 > /etc/apt/sources.list.d/rocm-build.list' && \ - amdgpu-repo --amdgpu-build=1736298; \ + elif [ "$ROCMVERSION" = "6.2" ] && [ "$compiler_version" = "rc1" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.2-20.04-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.2-20.04-1_all.deb && \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.2 rel-8 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=1794148; \ fi RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" -- GitLab From 2525864fdae349ce8e839faec8971be915116b0f Mon Sep 17 00:00:00 2001 From: Ruturaj Vaidya Date: Thu, 27 Jun 2024 14:34:25 -0500 Subject: [PATCH 72/96] Update CMakeLists.txt (#1364) It is a good practice to check if the file CMakeLists.txt is in fact in the directory. --- example/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index c9781637d..87c5a89f8 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -181,7 +181,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) FOREACH(subdir ${dir_list}) - IF(IS_DIRECTORY "${subdir}") + if(IS_DIRECTORY "${subdir}" AND EXISTS "${subdir}/CMakeLists.txt") add_subdirectory(${subdir}) ENDIF() ENDFOREACH() -- GitLab From 614ebd050ad36aa78dd84342e838b53fda7a9996 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 27 Jun 2024 22:14:36 -0700 Subject: [PATCH 73/96] Bump rocm-docs-core from 1.4.0 to 1.4.1 in /docs/sphinx (#1367) Bumps [rocm-docs-core](https://github.com/ROCm/rocm-docs-core) from 1.4.0 to 1.4.1. 
- [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.4.0...v1.4.1) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 9c9706c66..6605380a5 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.4.0 +rocm-docs-core==1.4.1 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 313c30026..a3566090e 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==1.4.0 +rocm-docs-core==1.4.1 # via -r requirements.in six==1.16.0 # via -- GitLab From 497ccb872b6a9a921c01df4dca49dac7cb242c72 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 28 Jun 2024 06:50:46 -0700 Subject: [PATCH 74/96] fix the optional ckProfiler grouped_gemm arguments (#1368) --- profiler/src/profile_grouped_gemm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp index 25203d7b6..fbf44d720 100644 --- a/profiler/src/profile_grouped_gemm.cpp +++ b/profiler/src/profile_grouped_gemm.cpp @@ -98,8 +98,8 @@ int profile_grouped_gemm(int argc, char* argv[]) int n_iter = 10; if(argc == 17) { - n_warmup = std::stoi(argv[16]); - n_iter = std::stoi(argv[17]); + n_warmup = std::stoi(argv[15]); + n_iter = std::stoi(argv[16]); } #ifdef CK_ENABLE_FP16 -- GitLab From 959073842c0db839d45d565eb260fd018c996ce4 Mon Sep 17 00:00:00 2001 From: Jun Liu Date: Wed, 3 Jul 2024 23:34:38 -0700 Subject: [PATCH 75/96] Fix issue with multiple targets and remove smfmac tests from unsupported test targets (#1372) --- Jenkinsfile | 4 +-- .../25_wrapper/wrapper_basic_gemm.cpp | 17 +++++++++-- .../25_wrapper/wrapper_optimized_gemm.cpp | 16 ++++++++-- .../gemm_bilinear_wmma_fp16.cpp | 9 ++++++ .../gemm_bilinear_wmma_int8.cpp | 9 ++++++ ...ouped_conv_fwd_bias_relu_add_wmma_fp16.cpp | 13 ++++++++- ...ouped_conv_fwd_bias_relu_add_wmma_int8.cpp | 13 ++++++++- ...e_scale_softmax_gemm_permute_wmma_fp16.cpp | 13 ++++++++- ...m_scale_softmax_gemm_permute_wmma_fp16.cpp | 13 ++++++++- .../cross_attention_forward_wmma_fp16.cpp | 13 ++++++++- ...uped_query_attention_forward_wmma_fp16.cpp | 13 ++++++++- ...ulti_query_attention_forward_wmma_fp16.cpp | 13 ++++++++- .../self_attention_forward_wmma_fp16.cpp | 13 ++++++++- .../grouped_conv_bwd_data_wmma_fp16.cpp | 13 ++++++++- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 20 ++++++------- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 29 ++++++++++--------- include/ck/utility/amd_smfmac.hpp | 28 ++++++++++++++++++ test/CMakeLists.txt | 6 +++- test/grouped_convnd_bwd_data/CMakeLists.txt | 8 ++--- ...grouped_convnd_bwd_data_interface_wmma.cpp | 8 +++++ test/grouped_convnd_bwd_weight/CMakeLists.txt | 8 ++--- ...ouped_convnd_bwd_weight_interface_wmma.cpp | 8 +++++ test/grouped_convnd_fwd/CMakeLists.txt | 2 +- test/wmma_op/wmma_op_util.hpp 
| 4 ++- 24 files changed, 243 insertions(+), 50 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 67e9b2fcb..8809fc50c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -886,10 +886,10 @@ pipeline { } agent{ label rocmnode("gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1100;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx908;gfx90a" \ + -DGPU_TARGETS="gfx1100;gfx90a" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } diff --git a/client_example/25_wrapper/wrapper_basic_gemm.cpp b/client_example/25_wrapper/wrapper_basic_gemm.cpp index 59c5c243c..23245dd18 100644 --- a/client_example/25_wrapper/wrapper_basic_gemm.cpp +++ b/client_example/25_wrapper/wrapper_basic_gemm.cpp @@ -7,19 +7,23 @@ #include #include +#include "ck/utility/common_header.hpp" +// __gfx9__ defined in the above header via ck.hpp +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/host_tensor.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/check_err.hpp" -#include "ck/utility/common_header.hpp" #include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" #include "ck/wrapper/layout.hpp" #include "ck/wrapper/tensor.hpp" #include "ck/wrapper/operations/copy.hpp" #include "ck/wrapper/operations/gemm.hpp" #include "ck/wrapper/utils/kernel_utils.hpp" +#include "ck/host_utility/device_prop.hpp" struct SimpleDeviceMem { @@ -204,6 +208,14 @@ void PerformGemm(const ck::index_t M, int main(int argc, char* argv[]) { + bool is_supported = ck::is_xdl_supported(); + if(!is_supported) + { + std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + using DataType = ck::half_t; const auto thread_layout = ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}), @@ -213,3 +225,4 @@ int main(int argc, char* argv[]) 3840, 4096, 4096, tile_shape, thread_layout); return 0; } +#endif diff --git a/client_example/25_wrapper/wrapper_optimized_gemm.cpp b/client_example/25_wrapper/wrapper_optimized_gemm.cpp index b6294c239..31e20342d 100644 --- a/client_example/25_wrapper/wrapper_optimized_gemm.cpp +++ b/client_example/25_wrapper/wrapper_optimized_gemm.cpp @@ -7,18 +7,21 @@ #include #include -#include "ck/library/utility/host_tensor.hpp" +#include "ck/utility/common_header.hpp" +// __gfx9__ defined in the above header via ck.hpp +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) #include "ck/host_utility/kernel_launch.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/check_err.hpp" -#include "ck/utility/common_header.hpp" #include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" #include "ck/wrapper/layout.hpp" #include "ck/wrapper/tensor.hpp" #include "ck/wrapper/operations/copy.hpp" #include "ck/wrapper/operations/gemm.hpp" #include "ck/wrapper/utils/kernel_utils.hpp" +#include "ck/host_utility/device_prop.hpp" struct SimpleDeviceMem { @@ -296,6 +299,14 @@ void PerformGemm(const ck::index_t M, int main(int argc, char* 
argv[]) { + bool is_supported = ck::is_xdl_supported(); + if(!is_supported) + { + std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + using DataType = ck::half_t; const auto thread_layout = ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), @@ -305,3 +316,4 @@ int main(int argc, char* argv[]) 3840, 4096, 4096, tile_shape, thread_layout); return 0; } +#endif diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index d1b820da7..18731e810 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -17,6 +17,7 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" struct AlphaBetaAdd { @@ -175,6 +176,14 @@ int main(int argc, char* argv[]) exit(0); } + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index aca136f80..87812369b 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -17,6 +17,7 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" struct AlphaBetaAdd { @@ -175,6 +176,14 @@ int main(int argc, char* argv[]) exit(0); } + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp index 039d25029..ff873d26b 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "common_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" // kernel data types using InKernelDataType = FP16; @@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); +} diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp index 793324970..662a6f611 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #include "common_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" // kernel data types using InKernelDataType = I8; @@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp index 2c7bacfc4..69ab5c5c0 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. 
Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -163,4 +164,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp index d9ab645ee..f5cedb14c 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -285,4 +286,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp index fac19f8b5..41c6dff2d 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. 
Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -351,4 +352,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_cross_attention_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp index 12dcfcc36..955c25f0d 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp @@ -28,6 +28,7 @@ Example is GQA-4 #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -299,4 +300,14 @@ using ReferenceGemm1Instance = #include "run_grouped_query_attention_forward_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp index 694a320a4..112be07c4 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp @@ -26,6 +26,7 @@ Shazeer, Noam. 
“Fast Transformer Decoding: One Write-Head Is All You Need.” #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -284,4 +285,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_ #include "run_multi_query_attention_forward_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp index d463cc871..9ec1bc933 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -329,4 +330,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_self_attention_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp index 5baa52150..3e3ae7edb 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp @@ -3,6 +3,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" #include "common.hpp" +#include "ck/host_utility/device_prop.hpp" using OutDataType = FP16; using WeiDataType = FP16; @@ -31,4 +32,14 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat #include "run_grouped_conv_bwd_data_example.inc" -int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run_grouped_conv_bwd_data_example(argc, argv); +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index f4f496fc1..d9e300b73 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -47,12 +47,12 @@ __global__ void #endif kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -103,12 +103,12 @@ __global__ void #endif kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 415ae3d49..a4d4a01a0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -69,14 +69,15 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffset compute_ptr_offset_of_groups, - const ComputePtrOffset compute_ptr_offset_of_n, - const index_t groups_count) + kernel_grouped_conv_fwd_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, + [[maybe_unused]] const index_t groups_count) { 
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group @@ -132,13 +133,13 @@ __global__ void #endif kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffset compute_ptr_offset_of_groups, - const ComputePtrOffset compute_ptr_offset_of_n, - const index_t groups_count) + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, + [[maybe_unused]] const index_t groups_count) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group diff --git a/include/ck/utility/amd_smfmac.hpp b/include/ck/utility/amd_smfmac.hpp index 234293085..abb8d9f5e 100644 --- a/include/ck/utility/amd_smfmac.hpp +++ b/include/ck/utility/amd_smfmac.hpp @@ -16,8 +16,15 @@ struct intrin_smfmac_f32_16x16x32f16<16, 16> __device__ static void Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) { +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif } }; @@ -31,8 +38,15 @@ struct intrin_smfmac_f32_16x16x32bf16<16, 16> __device__ static void Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) { +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif } }; @@ -46,8 +60,15 @@ struct intrin_smfmac_f32_32x32x16f16<32, 32> __device__ static void Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) { +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif } }; @@ -61,8 +82,15 @@ struct intrin_smfmac_f32_32x32x16bf16<32, 32> __device__ static void Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) { +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif } }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 7ee37d211..3b121fc30 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -71,6 +71,8 @@ function(add_test_executable TEST_NAME) list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) elseif(ARGN MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + elseif(ARGN MATCHES "_smfmac") + list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a) endif() 
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) @@ -150,6 +152,8 @@ function(add_gtest_executable TEST_NAME) list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) elseif(ARGN MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + elseif(ARGN MATCHES "_smfmac") + list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) @@ -209,7 +213,7 @@ add_subdirectory(wrapper) if(GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() -if(GPU_TARGETS MATCHES "gfx942") +if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2 add_subdirectory(smfmac_op) endif() add_subdirectory(position_embedding) diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 3507989ba..8edb71520 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -2,11 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_x if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) endif() -add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) +add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) + target_link_libraries(test_grouped_convnd_bwd_data_interface_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance) endif() -add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) +add_gtest_executable(test_grouped_convnd_bwd_data_interface_wmma test_grouped_convnd_bwd_data_interface_wmma.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) + target_link_libraries(test_grouped_convnd_bwd_data_interface_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance) endif() diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp index c0429c6d0..fbb6ffc6f 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp @@ -52,6 +52,14 @@ class TestGroupedConvndBwdData : public ::testing::Test ck::utils::conv::ConvParam conv_param; + void SetUp() override + { + if(!ck::is_gfx11_supported()) + { + GTEST_SKIP(); + } + } + template bool Run() { diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index 54b514e7a..313b5ba4c 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -5,13 +5,13 @@ if(GPU_TARGETS MATCHES "gfx9" OR DL_KERNELS) add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance) 
endif() -add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) +add_gtest_executable(test_grouped_convnd_bwd_weight_interface_xdl test_grouped_convnd_bwd_weight_interface_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) + target_link_libraries(test_grouped_convnd_bwd_weight_interface_xdl PRIVATE utility) endif() -add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) +add_gtest_executable(test_grouped_convnd_bwd_weight_interface_wmma test_grouped_convnd_bwd_weight_interface_wmma.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) + target_link_libraries(test_grouped_convnd_bwd_weight_interface_wmma PRIVATE utility) endif() add_gtest_executable(test_grouped_conv_bwd_weight_xdl_bilinear test_grouped_conv_bwd_weight_xdl_bilinear.cpp) if(result EQUAL 0) diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp index 1dcb8f866..2e2f5332a 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp @@ -52,6 +52,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ck::utils::conv::ConvParam conv_param; + void SetUp() override + { + if(!ck::is_gfx11_supported()) + { + GTEST_SKIP(); + } + } + template bool Run() { diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 1eba91382..f611e6624 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -1,6 +1,6 @@ if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) - if(GPU_TARGETS MATCHES "gfx11") + if((GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) else() target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index d9ec94771..3e511ab5b 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -11,6 +11,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/utility/amd_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace wmma_op_util { @@ -373,7 +374,8 @@ struct TestWmma a, b, c_host, a_element_op, b_element_op, c_element_op); // Act - bool is_supported = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device); + bool is_supported = ck::is_gfx11_supported() && + ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device); if(is_supported) { -- GitLab From eaa870a1ab91fedfc614609fbb6e843ae5231dca Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 4 Jul 2024 12:00:14 +0200 Subject: [PATCH 76/96] Add structural sparsity xdlops (#1363) * Implemented smfmac xdlops * add reviewer comments --- .../gpu/warp/smfmac_xdlops_gemm.hpp | 409 ++++++++++++++++++ 1 file changed, 409 insertions(+) create mode 100644 
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp diff --git a/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp new file mode 100644 index 000000000..33c07f34f --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp @@ -0,0 +1,409 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_smfmac.hpp" + +namespace ck { + +enum struct SmfmacInstr +{ + smfmac_f32_16x16x32f16 = 0, + smfmac_f32_32x32x16f16, + smfmac_f32_16x16x32bf16, + smfmac_f32_32x32x16bf16, +}; + +template +struct smfmac_type; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_16x16x32f16::Run(a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_32x32x16f16::Run(a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_16x16x32bf16::Run(a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const 
int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_32x32x16bf16::Run(a, b, idx, reg_c); + } +}; + +template +struct SmfmacSelector +{ + template + static constexpr auto GetSmfmac(); + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_16x16x32f16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_32x32x16f16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_16x16x32bf16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_32x32x16bf16; + } + + static constexpr auto selected_smfmac = + smfmac_type()>{}; + + __host__ __device__ constexpr SmfmacSelector() + { + static_assert(selected_smfmac.group_size * selected_smfmac.num_groups_per_blk == + selected_smfmac.num_regs_per_blk, + "wrong! num_regs_per_blk"); + + static_assert(selected_smfmac.num_threads_per_blk == selected_smfmac.n_per_blk, + "n_per_blk != num_threads_per_blk"); + + static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.num_input_blks == + selected_smfmac.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + + static_assert(selected_smfmac.num_output_blks == selected_smfmac.num_input_blks || + selected_smfmac.num_output_blks == 1, + "incorrect num_output_blks"); + + static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.wave_size == + selected_smfmac.m_per_blk * selected_smfmac.n_per_blk, + "num_regs_per_blk incorrect"); + + static_assert(selected_smfmac.is_k_reduction || + (selected_smfmac.num_input_blks == selected_smfmac.num_output_blks), + "is_k_reduction wrong!"); + } + + static constexpr index_t GetKPerXdlops() + { + return (selected_smfmac.is_k_reduction ? selected_smfmac.num_input_blks : 1) * + selected_smfmac.k_per_blk; + } + + static constexpr index_t GetK1PerXdlops() { return selected_smfmac.k_per_blk; } +}; + +template +struct SparseXdlopsGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __device__ static constexpr index_t GetNumBlks() { return smfmac_instr.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / + (smfmac_instr.m_per_blk * smfmac_instr.n_per_blk * smfmac_instr.num_output_blks); + } + + __host__ __device__ constexpr SparseXdlopsGemm() + { + static_assert(NPerXdlops == 16 || NPerXdlops == 32, + "Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops"); + + static_assert(MPerXdlops == 16 || MPerXdlops == 32, + "Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops"); + + static_assert(KPack % smfmac_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + } + + // XDL output supporting C = A * B + // M2_N2 -> M2_M3_M4_N2 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + 
make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); + } + + template + __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) + { + const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_g_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(G), + make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(smfmac_instr.num_groups_per_blk, + smfmac_instr.num_input_blks, + smfmac_instr.group_size)), + make_pass_through_transform(smfmac_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{}, + Sequence<8>{})); + } + + __device__ static constexpr index_t GetRegSizePerXdlops() + { + return MPerXdlops * NPerXdlops / smfmac_instr.wave_size; + } + + __device__ static constexpr index_t GetWaveSize() { return smfmac_instr.wave_size; } + + template + __device__ void + Run(const FloatA& p_a_wave, const FloatB& p_b_wave, const Idx& idx, FloatC& p_c_thread) const + { + static_assert(is_same::value || is_same::value, + "base base_type must be half or bfloat16!"); + + static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) { + smfmac_instr.template run( + p_a_wave[k], p_b_wave[k], idx[k], p_c_thread); + }); + } + + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, smfmac_instr.num_input_blks, smfmac_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(smfmac_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(smfmac_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { 
+ return make_tuple(0, laneId); + } + } + + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * smfmac_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * smfmac_instr.m_per_blk + blk_id * smfmac_instr.group_size; + + return CIndex{m_offset, n_offset}; + } + + __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + return CIndex4D{I0, blk_id, I0, blk_td}; + } + + static constexpr auto smfmac = + SmfmacSelector{}; + + static constexpr auto smfmac_instr = smfmac.selected_smfmac; + + static constexpr auto KPerXdlops = smfmac.GetKPerXdlops(); + static constexpr auto K1PerXdlops = smfmac.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() + { + return make_tuple( + Number{}, I1, Number{}, I1); + } +}; + +} // namespace ck -- GitLab From 75e622f02f964523dc2c60125904ce4018fae0f2 Mon Sep 17 00:00:00 2001 From: Harisankar Sadasivan <135730918+hsadasiv@users.noreply.github.com> Date: Fri, 5 Jul 2024 21:40:30 -0700 Subject: [PATCH 77/96] Universal streamk with atomics (#1360) * universal streamk with atomics with ckprofiler support. grid_size and streamk strategy are tunable. grid_size of -1 leads to #WGs = maximum occupancy X num_CUs. implementation supports many different streamk policies: 1-tile, 2-tile, 3-tile and 4-tile. streamk strategy of -1 leads to default streamk policy (4-tile). * Update README.md * fixing clang-format issues * removed conflicts in struct members between streamk and universal streamk * corrected arg parsing for streamk and universal streamk * added stream-k policies for 3 tile and 4 tile * fixed argument type issue with parsing cmd args * changes suggested in PR review are made- removing comments and correcting copyright * file permissions updated * added default value support for grid_size and streamk-policy selection set to -1 * print messages for arguments * print messages for arguments * print messages for arguments1 --- .pre-commit-config.yaml | 0 example/01_gemm/CMakeLists.txt | 2 + example/01_gemm/README.md | 18 + example/01_gemm/common.hpp | 69 +- example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp | 48 + .../01_gemm/run_gemm_example_streamk_v2.inc | 298 +++ .../gpu/device/device_gemm_streamk_v2.hpp | 44 + .../device_gemm_xdl_cshuffle_streamk_v3.hpp | 556 +++++ .../gpu/grid/block_to_ctile_map.hpp | 322 +++ .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 2010 +++++++++++++++++ .../gpu/gemm_universal_streamk.hpp | 337 +++ .../gpu/gemm_universal_streamk/CMakeLists.txt | 26 + ...universal_streamk_f16_f16_f16_mk_kn_mn.hpp | 91 + ...f16_f16_mk_kn_mn_comp_default_instance.cpp | 30 + ...16_f16_mk_kn_mn_comp_kpadding_instance.cpp | 30 + ..._f16_mk_kn_mn_comp_mnkpadding_instance.cpp | 30 + ...6_f16_mk_kn_mn_comp_mnpadding_instance.cpp | 30 + ...6_f16_mk_kn_mn_mem_v1_default_instance.cpp | 31 + ..._f16_mk_kn_mn_mem_v1_kpadding_instance.cpp | 31 + ...16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp | 31 + ...6_f16_mk_kn_mn_mem_v2_default_instance.cpp | 31 + ..._f16_mk_kn_mn_mem_v2_kpadding_instance.cpp | 31 + ...16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp | 31 + ...universal_streamk_f16_f16_f16_mk_nk_mn.hpp | 98 + ...f16_f16_mk_nk_mn_comp_default_instance.cpp | 30 
+ ...16_f16_mk_nk_mn_comp_kpadding_instance.cpp | 30 + ..._f16_mk_nk_mn_comp_mnkpadding_instance.cpp | 30 + ...6_f16_mk_nk_mn_comp_mnpadding_instance.cpp | 30 + ...6_f16_mk_nk_mn_mem_v1_default_instance.cpp | 31 + ..._f16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 31 + ...16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp | 31 + ...6_f16_mk_nk_mn_mem_v2_default_instance.cpp | 31 + ..._f16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 31 + ...16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp | 31 + .../gemm_universal_streamk/CMakeLists.txt | 26 + ...universal_streamk_f16_f16_f16_mk_kn_mn.hpp | 91 + ...f16_f16_mk_kn_mn_comp_default_instance.cpp | 30 + ...16_f16_mk_kn_mn_comp_kpadding_instance.cpp | 30 + ..._f16_mk_kn_mn_comp_mnkpadding_instance.cpp | 30 + ...6_f16_mk_kn_mn_comp_mnpadding_instance.cpp | 30 + ...6_f16_mk_kn_mn_mem_v1_default_instance.cpp | 31 + ..._f16_mk_kn_mn_mem_v1_kpadding_instance.cpp | 31 + ...16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp | 31 + ...6_f16_mk_kn_mn_mem_v2_default_instance.cpp | 31 + ..._f16_mk_kn_mn_mem_v2_kpadding_instance.cpp | 31 + ...16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp | 31 + ...universal_streamk_f16_f16_f16_mk_nk_mn.hpp | 98 + ...f16_f16_mk_nk_mn_comp_default_instance.cpp | 30 + ...16_f16_mk_nk_mn_comp_kpadding_instance.cpp | 30 + ..._f16_mk_nk_mn_comp_mnkpadding_instance.cpp | 30 + ...6_f16_mk_nk_mn_comp_mnpadding_instance.cpp | 30 + ...6_f16_mk_nk_mn_mem_v1_default_instance.cpp | 31 + ..._f16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 31 + ...16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp | 31 + ...6_f16_mk_nk_mn_mem_v2_default_instance.cpp | 31 + ..._f16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 31 + ...16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp | 31 + .../profile_gemm_universal_streamk_impl.hpp | 332 +++ profiler/src/CMakeLists.txt | 2 + .../src/profile_gemm_universal_streamk.cpp | 156 ++ script/check_copyright_year.sh | 0 61 files changed, 5846 insertions(+), 2 deletions(-) mode change 100644 => 100755 .pre-commit-config.yaml create mode 100644 example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp create mode 100644 example/01_gemm/run_gemm_example_streamk_v2.inc create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp mode change 100644 => 100755 profiler/src/CMakeLists.txt create mode 100644 profiler/src/profile_gemm_universal_streamk.cpp mode change 100755 => 100644 script/check_copyright_year.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100644 new mode 100755 diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 23683de44..98fd9c6b7 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -22,6 +22,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16) add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2) +add_example_executable(example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3) add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) diff --git a/example/01_gemm/README.md b/example/01_gemm/README.md 
index a09e69255..5edec1f04 100644 --- a/example/01_gemm/README.md +++ b/example/01_gemm/README.md @@ -7,3 +7,21 @@ #arg3: run kernel # of times (>1) ./bin/example_gemm_xdl 0 1 5 ``` + +# Instructions for ```example_gemm_xdl_fp16_streamk_v3``` + +## Run ```example_gemm_xdl_fp16_streamk_v3``` +```bash +arg1: verification (0=no, 1=yes) +arg2: initialization (0=no init, 1=integer value, 2=decimal value) +arg3: time kernel (0=no, 1=yes) +arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC +arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK) +arg11: Grid_size(-1 for max occupancy) +bin/example_gemm_xdl_fp16_streamk_v3 1 2 1 3840 4096 4096 4096 4096 4096 1 -1 +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:4032, NP:4096, KRead:4096, KP:4096, AK0:512, BK0:2048, MBlock: 18, NBlock: 16, Stream-K Selection:1, Grid size:-1} +Perf: 0.292022 ms, 441.23 TFlops, 330.348 GB/s, DeviceGemmXdlUniversal BlkSize: 256, BlkTile: 224x256x64, WaveTile: 16x16, WaveMap: 7x8, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2 +``` diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index ef87d9c2f..3d8f4565c 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,6 +45,19 @@ struct ProblemSizeStreamK final ck::index_t NumSKBlocks = -1; }; +struct ProblemSizeStreamK_universal final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t Grid_size = -1; // defaults to max occupancy + ck::index_t Streamk_sel = 1; // defaults to 1-tile SK +}; struct ProblemSizeSplitK final { @@ -123,6 +136,57 @@ bool parse_cmd_args(int argc, return true; } +template <> +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeStreamK_universal& problem_size, + ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + + if(argc >= 11) + { + problem_size.Streamk_sel = std::stoi(argv[10]); + problem_size.Grid_size = std::stoi(argv[11]); + } + } + else + { + std::cerr + << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)" + << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; + 
return false; + } + + return true; +} + template <> bool parse_cmd_args(int argc, char* argv[], @@ -165,7 +229,8 @@ bool parse_cmd_args(int argc, << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl - << "arg10: NumSKBlocks(optional)" << std::endl; + << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)" + << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; return false; } diff --git a/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp new file mode 100644 index 000000000..5b163962b --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 224, 256, + 64, 8, 2, + 16, 16, + 7, 8, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 2, 0, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc new file mode 100644 index 000000000..6679f9515 --- /dev/null +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -0,0 +1,298 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
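
[Editor's aside, not part of the patch] The README section and the `parse_cmd_args` specialization above define the command-line contract for the new stream-k example: args 4 to 9 are M, N, K and the three strides, arg 10 selects the stream-k strategy (-1 default, 0 all DP, 1 one-tile SK, 2 two-tile SK), and arg 11 is the grid size (-1 for max occupancy). A minimal standalone sketch of that contract follows; `StreamKProblem` and `parse_streamk_args` are hypothetical stand-ins for the patch's `ProblemSizeStreamK_universal` and the `common.hpp` helpers, shown only to make the argument layout explicit.

```cpp
// Illustrative sketch only: mirrors the argument layout documented in
// example/01_gemm/README.md for example_gemm_xdl_fp16_streamk_v3.
#include <cstdlib>

struct StreamKProblem // hypothetical stand-in for ProblemSizeStreamK_universal
{
    int M = 3840, N = 4096, K = 4096;
    int StrideA = 4096, StrideB = 4096, StrideC = 4096;
    int Streamk_sel = 1; // -1: default, 0: all DP, 1: 1-tile SK, 2: 2-tile SK
    int Grid_size   = -1; // -1: let the kernel pick a max-occupancy grid
};

inline bool parse_streamk_args(int argc, char* argv[], StreamKProblem& p)
{
    // Accepted shapes: no args (defaults), 3 flags only, or flags + problem size.
    if(argc != 1 && argc != 4 && argc < 10)
        return false;
    if(argc >= 10)
    {
        p.M = std::atoi(argv[4]); p.N = std::atoi(argv[5]); p.K = std::atoi(argv[6]);
        p.StrideA = std::atoi(argv[7]); p.StrideB = std::atoi(argv[8]); p.StrideC = std::atoi(argv[9]);
    }
    if(argc >= 12) // stream-k select and grid size are supplied together
    {
        p.Streamk_sel = std::atoi(argv[10]);
        p.Grid_size   = std::atoi(argv[11]);
    }
    return true;
}
```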
+ +#pragma once + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto Grid_size = problem_size.Grid_size; + auto Streamk_sel = problem_size.Streamk_sel; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + auto f_get_default_streamk_policy = [](ck::index_t streamk_sel) { + if(streamk_sel == -1) + { + return static_cast(4); + } + else + return static_cast(streamk_sel); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Streamk_sel = f_get_default_streamk_policy(Streamk_sel); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + 
b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2_Streamk_Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + Streamk_sel, + Grid_size, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); +#endif + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = 
static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_universal_streamk_example(int argc, char* argv[]) +{ + ProblemSizeStreamK_universal problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp new file mode 100644 index 000000000..1a4d684f1 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Streamk_V2 : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t Streamk_sel, + ck::index_t Grid_size, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp new file mode 100644 index 000000000..452063156 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
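
[Editor's aside, not part of the patch] The timing block in `run_gemm` above converts the measured kernel time (milliseconds) into throughput with `flop = 2 * M * N * K` and `tflops = flop / 1e9 / ave_time`; dividing by milliseconds absorbs the remaining factor of 1e3, so the result is indeed TFLOPS. The standalone sketch below reproduces that arithmetic with the numbers quoted in the README run (3840x4096x4096 fp16 at 0.292 ms), which lands on the ~441 TFLOPS and ~330 GB/s shown there.

```cpp
// Illustrative only: reproduces the perf arithmetic from run_gemm above.
#include <cstdio>

int main()
{
    const double M = 3840, N = 4096, K = 4096;
    const double ms             = 0.292; // measured kernel time in milliseconds (from the README)
    const double bytes_per_elem = 2.0;   // fp16 for A, B and C

    const double flop  = 2.0 * M * N * K;                      // one multiply + one add per MAC
    const double bytes = bytes_per_elem * (M * K + K * N + M * N);

    const double tflops   = flop / 1.e9 / ms;  // flop / 1e9 per ms == flop / 1e12 per s, ~441
    const double gb_per_s = bytes / 1.e6 / ms; // ~330
    std::printf("%.1f TFLOPS, %.1f GB/s\n", tflops, gb_per_s);
    return 0;
}
```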
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_streamk_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + float ave_time = 0; + + index_t k_grain = KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + hipGetErrorString(hipMemsetAsync( + arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_)); + const auto Run = [&](const auto& kernel) { + dim3 grid_dim; + if(arg.Grid_size < 0) + { + int occupancy, num_cu; + hipError_t rtn; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, kernel, BlockSize, 0); + hip_check_error(rtn); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + rtn = hipGetDevice(&dev); + hip_check_error(rtn); + rtn = hipGetDeviceProperties(&dev_prop, dev); + hip_check_error(rtn); + num_cu = dev_prop.multiProcessorCount; + + arg.Grid_size = num_cu * occupancy; + grid_dim = arg.Grid_size; + } + else + grid_dim = arg.Grid_size; + + if(stream_config.flush_cache) + { + Argument arg_ = arg; + ck::utility::RotatingMemWrapper rotating_mem( + arg_, + stream_config.rotating_count, + arg_.M * arg_.K * sizeof(ADataType), + arg_.K * arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_); + } + else + { + + ave_time = launch_and_time_kernel( + stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + 
} + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + + return Argument{ + p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel, + Grid_size); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + 
{BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +struct BlockToCTileMap_GemmStreamK_v2 +{ + static constexpr uint32_t min_k_iters_per_sk_block = 2; + static constexpr uint32_t MPerBlock = MPerBlock_; + static constexpr uint32_t NPerBlock = NPerBlock_; + static constexpr uint32_t KPerBlock = KPerBlock_; + static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_; + static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_; + + //-------------------------------------- + // pass to device + mutable uint32_t sk_num_blocks; + uint32_t sk_num_big_blocks; + uint32_t dp_start_block_idx; + uint32_t reduction_start_block_idx; + uint32_t k_iters_per_big_block; + MDiv2 n_tiles; + MDiv k_iters_per_tile; + MDiv equiv_tiles_big; // for reduction + MDiv equiv_tiles_little; // for reduction + + // prefer construct on host + __host__ __device__ BlockToCTileMap_GemmStreamK_v2( + uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1) + { + // total output tiles + uint32_t num_tiles = + math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock); + k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock)); + + uint32_t dp_tiles, dp_num_blocks, sk_total_iters; + + // default to regular DP GEMM if sk blocks == 0 + if(streamk_sel == 0) + { + sk_num_blocks = 0; + dp_tiles = num_tiles; + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + + dp_num_blocks = num_tiles; // all tile to be dp block + dp_start_block_idx = 0; + sk_total_iters = 0; // clear this tiles + } + // 2-tile sk + DP GEMM + else + { + + // check if there's enough work for DP+ stream-k + bool bigEnough = num_tiles > grid_size; + // select between stream-k strategies + uint32_t sk_tiles = 0; + if(streamk_sel == 1) // 1 tile stream-k + { + sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles; + } + else if(streamk_sel == 2) // 2-tile stream-k + { + sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles; + } + else if(streamk_sel == 3) // 3-tile stream-k + { + sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size) + : num_tiles; + } + else if(streamk_sel == 4) // 4-tile stream-k + { + sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size) + : num_tiles; + } + sk_num_blocks = sk_tiles; + // remaining tiles are DP tiles + dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0; + + sk_total_iters = k_iters_per_tile.get() * sk_tiles; + + // k_iters_per_sk_block is the floor of avg each ck block loop over tiles. 
+ // we need to decide how many iters for each sk block + // let m = k_iters_per_sk_block + // some of the sk block (little) will cover m iters, some (big) will cover m+1 + // we have + // 1) l + b = sk_blocks + // 2) l * m + b * (m + 1) = sk_total_iters + // => (l + b) * m + b = sk_total_iters + // => sk_blocks * m + b = sk_total_iters + // => b = sk_total_iters - m * sk_blocks + // NOTE: big could be zero + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + + dp_num_blocks = dp_tiles; + dp_start_block_idx = sk_num_blocks; + } + + n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); + // using multiple blocks for parallel reduction + reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; + + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); + uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get()); + equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); + equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + } + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + __host__ __device__ uint32_t get_sk_total_iters() const + { + uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block + + (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1); + return sk_total_iters; + } + + __host__ __device__ uint32_t get_sk_tiles() const + { + // tiles for sk + uint32_t sk_total_iters = get_sk_total_iters(); + return k_iters_per_tile.div(sk_total_iters); + } + + __host__ __device__ index_t get_grid_dims() const + { + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + // return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1); + return reduction_start_block_idx + get_sk_tiles(); + } + else + return reduction_start_block_idx; + } + + __device__ uint32_t get_block_idx() const + { + // TODO: swizzle block index for better locality + return __builtin_amdgcn_readfirstlane(blockIdx.x); + } + + __device__ void + get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const + { + if(block_idx < sk_num_big_blocks) + { + iter_start = block_idx * k_iters_per_big_block; + iter_end = iter_start + k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + iter_start = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + iter_end = iter_start + (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + uint32_t sk_total_iters = get_sk_total_iters(); + uint32_t dp_iters_per_block = k_iters_per_tile.get(); + iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + iter_end = iter_start + dp_iters_per_block; + } + } + + __device__ uint32_t get_current_iter_length(uint32_t iter_start, + uint32_t iter_end, + uint32_t total_iter_length) const + { + uint32_t iter_length_mod, iter_length_quo /*unused*/; + k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); + uint32_t current_iter_length = math::min( + iter_length_mod == 0 ? 
(iter_end - iter_start) : iter_length_mod, total_iter_length); + return current_iter_length; + } + + __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); } + + __device__ void + get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const + { + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); + } + + __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const + { + uint32_t m_tile_idx, n_tile_idx; + uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock); + n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx); + + // // swizzle tile + uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock); + + uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m; + + const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem)) + ? tile_swizzle_sub_m + : tile_swizzle_sub_m_rem; + + uint32_t m_tile_idx_sub0, m_tile_idx_sub1; + m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m; + m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m; + + uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value; + + uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt; + + n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt; + m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt; + return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m, + n_tile_idx_with_adapt); + } + + __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const + { + static constexpr uint32_t alignment = 128; + uint32_t acc_buffer_bytes = + MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes; + return (acc_buffer_bytes + alignment - 1) / alignment * alignment; + } + + __host__ __device__ uint32_t get_workspace_size_for_semaphore() const + { + return get_sk_tiles() * sizeof(uint32_t); + } + + __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const + { + return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore(); + } + + __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, + const MDiv& equiv_tiles_) const + { + uint32_t tile_idx_ = tiles_ == 0 ? 
0 : (tiles_ - 1); + uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1; + uint32_t quo_, rem_; + equiv_tiles_.divmod(tile_idx_, quo_, rem_); + return quo_ * max_equiv_tiles_ + rem_; + } + + __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, + uint32_t iters_per_sk_block_) const + { + return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() - + 1); + } + + __host__ __device__ uint32_t get_total_acc_buffers() const + { + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + uint32_t tiles_cover_little_blocks = + get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1); + + uint32_t total_intersec_big = + get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big); + uint32_t total_intersec_little = + get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little); + + return sk_num_blocks + total_intersec_big + total_intersec_little; + } + + __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const + { + // TODO: from big to little + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + if(tile_idx_ < tiles_cover_big_blocks) + { + uint32_t touched_sk_blocks = + (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) / + k_iters_per_big_block; + uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big); + return touched_sk_blocks + current_intersec; + } + else + { + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_; + uint32_t touched_sk_blocks = + (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) / + iters_per_little_sk_block; + uint32_t current_intersec = + get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little); + return get_total_acc_buffers() - (touched_sk_blocks + current_intersec); + } + } + + __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const + { + uint32_t iters_per_big_sk_block = k_iters_per_big_block; + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + if(block_idx_ < sk_num_big_blocks) + { + uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block + + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big); + return block_idx_ + current_intersec; + } + else + { + uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_; + uint32_t touched_tiles = k_iters_per_tile.div( + block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little); + return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec); + } + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp new file mode 100644 index 000000000..ff1021535 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -0,0 +1,2010 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
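
[Editor's aside, not part of the patch] The comment in `BlockToCTileMap_GemmStreamK_v2` above derives the split of stream-k blocks into "big" blocks (m+1 K-iterations) and "little" blocks (m K-iterations): from `little + big = sk_num_blocks` and `little*m + big*(m+1) = sk_total_iters` it follows that `big = sk_total_iters - m * sk_num_blocks` with `m = sk_total_iters / sk_num_blocks`. The check below uses made-up illustrative values (37 iterations over 8 blocks), not numbers from the patch.

```cpp
// Illustrative check of the big/little stream-k block split derived above.
#include <cassert>
#include <cstdint>

int main()
{
    const uint32_t sk_total_iters = 37; // illustrative: K-iterations assigned to stream-k blocks
    const uint32_t sk_num_blocks  = 8;  // illustrative: number of stream-k blocks

    const uint32_t m      = sk_total_iters / sk_num_blocks;     // iters per "little" block -> 4
    const uint32_t big    = sk_total_iters - m * sk_num_blocks; // blocks doing m+1 iters   -> 5
    const uint32_t little = sk_num_blocks - big;                // blocks doing m iters     -> 3

    assert(little * m + big * (m + 1) == sk_total_iters);       // 3*4 + 5*5 == 37
    return 0;
}
```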
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_streamk_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return 
math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, 
KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + 
__host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t Streamk_sel_, + index_t Grid_size_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + Streamk_sel{Streamk_sel_}, + Grid_size{Grid_size_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, 1)}, + KPadded{CalculateKPadded(K_, 1)}, + AK0{CalculateAK0Padded(K_, 1)}, + BK0{CalculateBK0Padded(K_, 1)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel + << ", Grid size:" << Grid_size << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t Streamk_sel; + mutable index_t Grid_size; + index_t MPadded; + index_t 
NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t Streamk_sel_, + index_t Grid_size_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Problem& problem, unsigned int kbatch_id, unsigned int orig_K) + { + if constexpr(is_same_v) + { + a_k_split_offset = kbatch_id * problem.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = kbatch_id * problem.KRead * problem.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = kbatch_id * problem.KRead * problem.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = kbatch_id * problem.KRead; + } + + if(kbatch_id < static_cast(problem.KBatch - 1)) + { + problem.K = problem.KRead; + } + else + { + problem.K = orig_K - problem.KRead * (problem.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. 
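+ // How the constants below are chosen: kfold folds extra K0 slices into one permuted row
+ // whenever AK1 * M0 * sizeof(ADataType) is under 128 bytes, mpair likewise pairs up to M0
+ // MPerXdl-wide slices under the same 128-byte bound, and KThreadReadPerm shrinks the
+ // read-side thread permutation when the folded write granularity exceeds the per-thread
+ // read granularity.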
+ constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + 
c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + + if(karg.K <= 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(is_same, bhalf_t>::value) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + Problem& problem) + { + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M, + problem.N, + AK0Number * problem.KPadded, + problem.Grid_size, + problem.Streamk_sel); + uint32_t iter_start, iter_end; + bool is_sk_block, is_dp_block; + index_t num_k_block_main_loop; + + for(auto 
block_idx = get_block_1d_id(); + block_idx < block_2_ctile_map_streamk.get_grid_dims(); + block_idx += gridDim.x) + { + + is_sk_block = + static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; + is_dp_block = + static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && + static_cast(block_idx) < + block_2_ctile_map_streamk.reduction_start_block_idx; + + block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); + num_k_block_main_loop = iter_end - iter_start; + + while(true) + { + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_2_ctile_map_streamk.get_current_iter_length( + iter_start, iter_end, num_k_block_main_loop)); + uint32_t tile_idx, iter_offset; + block_2_ctile_map_streamk.get_tile_idx_with_offset( + iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K, + problem.KPadded, + problem.N, + problem.NPadded, + problem.StrideB, + problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto block_work_idx = + block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + 
BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = + make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = + make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 
m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + // CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = SpaceFillingCurve< + Sequence<1, MPerBlock, 1, NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + if(is_dp_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(is_sk_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // exit condition + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + // make sure next loop LDS is ready for use + block_sync_lds(); + } + } 
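+ // End of the persistent loop over virtual block ids: for each id, the while loop above
+ // walked that block's [iter_start, iter_end) share of K iterations, contributing to one
+ // output tile per pass until the stream-K work assigned to this block was exhausted.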
+ } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + Problem& problem) + { + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + Block2CTileMap_streamk block_2_ctile_map_streamk( + problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size); + uint32_t iter_start, iter_end; + bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block; + index_t num_k_block_main_loop; + + for(auto block_idx = get_block_1d_id(); + block_idx < block_2_ctile_map_streamk.get_grid_dims(); + block_idx += gridDim.x) + { + is_sk_block = + static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; + is_dp_block = + static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && + static_cast(block_idx) < + block_2_ctile_map_streamk.reduction_start_block_idx; + + block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); + num_k_block_main_loop = iter_end - iter_start; + + { + + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_2_ctile_map_streamk.get_current_iter_length( + iter_start, iter_end, num_k_block_main_loop)); + uint32_t tile_idx, iter_offset; + block_2_ctile_map_streamk.get_tile_idx_with_offset( + iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K, + problem.KPadded, + problem.N, + problem.NPadded, + problem.StrideB, + problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto block_work_idx = + block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto 
a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = + make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = + make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + 
b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); 
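+ // Epilogue: accumulators are staged through LDS in CShuffle-sized tiles so the final global
+ // store can be vectorized along N. A dp block owns its whole output tile, whereas an sk
+ // block only covers part of the K range, so its result is combined with the partial results
+ // of the other stream-k blocks that share the same tile.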
+ + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + // CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = SpaceFillingCurve< + Sequence<1, MPerBlock, 1, NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + if(is_dp_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(is_sk_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + } + } +}; + +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp new file mode 100644 index 000000000..19fa6c209 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp @@ -0,0 +1,337 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void 
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances); +#endif +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGemm_Streamk_V2; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + } +#endif + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt new file mode 100644 index 000000000..2a930ab9a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt @@ -0,0 +1,26 @@ +# ONLY XDL_KERNELS +set(GEMM_UNIVERSAL_STREAMK_INSTANCES) + +list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + 
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp) + +add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 000000000..6e8d5c798 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
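+
+#pragma once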
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 
1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 4, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 
2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, 
PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 000000000..6adcb8f4f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 000000000..631ae6872 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
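Each GemmSpecialization variant of the comp tuple gets its own translation unit like the one that begins here. A hedged completion of the add_device_operation_instances call below, assuming each per-specialization file passes the matching GemmSpecialization tag defined in the shared header (here KPadding):

    add_device_operation_instances(
        instances,
        device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances<GemmKPadding>{});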
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 000000000..2c49773a6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp new file mode 100644 index 000000000..39d54fb88 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 000000000..8ee50d63c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 000000000..d31e0819a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 000000000..fe19f35e5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 000000000..6c1873b37 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 000000000..ffd53f406 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 000000000..094b8f92f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp new file mode 100644 index 000000000..e00c1733e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| 
ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // AGPR Spill + // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. 
+ DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| 
_MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, 
BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 
2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 000000000..546f909b3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp new file mode 100644 index 000000000..d91de96be --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 000000000..c70678b44 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp new file mode 100644 index 000000000..5410a0cc2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp new file mode 100644 index 000000000..4ae7329f9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 000000000..4fc5458a9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 000000000..7369f87a5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 000000000..45425a41a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 000000000..3b5ac0366 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 000000000..53aa011a7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/CMakeLists.txt new file mode 100644 index 000000000..2a930ab9a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/CMakeLists.txt @@ -0,0 +1,26 @@ +# ONLY XDL_KERNELS +set(GEMM_UNIVERSAL_STREAMK_INSTANCES) + +list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp + 
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp) + +add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 000000000..6e8d5c798 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | + + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 
8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, 
S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 4, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, 
PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 000000000..6adcb8f4f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 000000000..631ae6872 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 000000000..2c49773a6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp new file mode 100644 index 000000000..39d54fb88 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 000000000..8ee50d63c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 000000000..d31e0819a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 000000000..fe19f35e5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 000000000..6c1873b37 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 000000000..ffd53f406 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 000000000..094b8f92f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp new file mode 100644 index 000000000..e00c1733e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 
0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // AGPR Spill + // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, 
PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + 
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 
0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 000000000..546f909b3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp new file mode 100644 index 000000000..d91de96be --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 000000000..c70678b44 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp new file mode 100644 index 000000000..5410a0cc2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp new file mode 100644 index 000000000..4ae7329f9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 000000000..4fc5458a9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 000000000..7369f87a5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 000000000..45425a41a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 000000000..3b5ac0366 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 000000000..53aa011a7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp new file mode 100644 index 000000000..72194e8e6 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -0,0 +1,332 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_universal_streamk_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int Streamk_sel, + int Grid_size, + int n_warmup, + int n_iter, + uint64_t rotating = 0) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes(); + int rotating_count = std::max( + 1, + std::min(n_iter, + static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + std::cout << "rotating count: " << rotating_count << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemm_Streamk_V2; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = 
ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_grid_size = 0; + float best_streamk_sel = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + std::vector grid_size_list = {38, 76, 114, 152, 190, 228, 266, 304, 342, 380}; + std::vector streamk_sel_list = { + 0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP, + // 2:2-tile Stream-K + DP + + if(Grid_size == -1) + { + grid_size_list = {Grid_size}; + } + if(Streamk_sel != -1) + { + streamk_sel_list = {Streamk_sel}; + } + for(std::size_t j = 0; j < streamk_sel_list.size(); j++) + { + for(std::size_t i = 0; i < grid_size_list.size(); i++) + { + auto grid_size_curr = grid_size_list[i]; + index_t streamk_sel_curr = streamk_sel_list[j]; + printf("streamk_sel_curr=%0d\n", streamk_sel_curr); + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel_curr, + grid_size_curr, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, + time_kernel, + 0, + n_warmup, + n_iter, + rotating_count > 1, + rotating_count}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", Grid_size " + << grid_size_curr << ", streamk selection strategy" + << streamk_sel_curr << std::endl; + +#if defined CK_ENABLE_FP8 + // set softer tolerances for fp8 + if constexpr(is_same_v || is_same_v || + is_same_v) + { + std::string msg = "Error: Incorrect results!"; + double rtol = 1e-1; + double atol = 1e-1; + pass = pass & ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); + } + else + { +#endif + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#if defined CK_ENABLE_FP8 + } +#endif + + if(tflops > best_tflops) + { + best_op_name = op_name; + 
best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_grid_size = grid_size_curr; + best_streamk_sel = streamk_sel_curr; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC + << " Grid_size = " << best_grid_size + << " Stream-K selection strategy = " << best_streamk_sel << " : " << best_ave_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt old mode 100644 new mode 100755 index 5262ca33a..c2a976972 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -52,6 +52,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) @@ -120,6 +121,7 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) diff --git a/profiler/src/profile_gemm_universal_streamk.cpp b/profiler/src/profile_gemm_universal_streamk.cpp new file mode 100644 index 000000000..cd3f5787d --- /dev/null +++ b/profiler/src/profile_gemm_universal_streamk.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
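Note on the numbers reported by profile_gemm_universal_streamk_impl above: the FLOP count is the usual 2*M*N*K, the byte count sums the A, B and C footprints, and both are divided by the averaged kernel time in milliseconds, so flop/1e9/ms comes out directly in TFLOP/s and bytes/1e6/ms in GB/s; the rotating-buffer count is the requested rotating size divided by the combined A+B footprint, clamped to [1, n_iter]. A minimal standalone sketch of that arithmetic follows; the 4096^3 fp16 problem, the 1 ms kernel time and the 512 MB rotating size are illustrative assumptions, not measurements from this patch.

    // Sketch of the metric arithmetic used by profile_gemm_universal_streamk_impl.
    // All concrete numbers below (problem size, element sizes, kernel time,
    // rotating size) are assumptions for illustration only.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <iostream>

    int main()
    {
        const std::int64_t M = 4096, N = 4096, K = 4096;
        const int sizeof_a = 2, sizeof_b = 2, sizeof_c = 2; // fp16 elements
        const double ave_time_ms = 1.0;                     // assumed kernel time
        const double rotating    = 512.0 * 1024 * 1024;     // assumed rotating size in bytes
        const int n_iter         = 50;

        const double flop      = 2.0 * M * N * K;
        const double num_bytes = sizeof_a * double(M) * K + sizeof_b * double(K) * N +
                                 sizeof_c * double(M) * N;

        const double tflops   = flop / 1.e9 / ave_time_ms;      // GFLOP per ms == TFLOP/s
        const double gb_per_s = num_bytes / 1.e6 / ave_time_ms; // MB per ms == GB/s

        const double ab_bytes = sizeof_a * double(M) * K + sizeof_b * double(K) * N;
        const int rotating_count =
            std::max(1, std::min(n_iter, static_cast<int>(std::ceil(rotating / ab_bytes))));

        // ~137.4 TFlops, ~100.7 GB/s, rotating_count = 8 for these assumed inputs
        std::cout << tflops << " TFlops, " << gb_per_s << " GB/s, rotating_count = "
                  << rotating_count << std::endl;
    }

For reference, the driver below takes its arguments in the order documented by its help text, so an illustrative invocation (the values are assumptions, not taken from this patch) would look like: ckProfiler gemm_universal_streamk 1 1 1 1 0 1 4096 4096 4096 -1 -1 -1 1 -1, i.e. fp16 data, MK_NK_MN layout, default strides, 1-tile Stream-K, and grid size left at -1 for maximum persistent-kernel occupancy.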
+ +#include +#include +#include +#include + +#include "profiler/profile_gemm_universal_streamk_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F16_F16, // 4 + F16_F8_F16, // 5 + F16_F16_F16_F8, // 6 +}; + +#define OP_NAME "gemm_universal_streamk" +#define OP_DESC "Universal Streamk GEMM" + +int profile_gemm_universal_streamk(int argc, char* argv[]) +{ + if(argc != 16 && argc != 19) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, " + "comp f8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: Stream-k select strategy 0: all DP, 1: 1-tile SK, 2: 2-tile SK\n"); + printf("arg15: Grid-size, -1 for max persistent kernel occupancy\n"); + printf("optional:\n"); + printf("arg16: number of warm-up cycles (default 1)\n"); + printf("arg17: number of iterations (default 10)\n"); + printf("arg18: memory for rotating buffer (default 0, size in MB)\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const int Streamk_sel = std::stoi(argv[14]); + const int Grid_size = std::stoi(argv[15]); + + int n_warmup = 20; + int n_iter = 50; + uint64_t rotating = 0; + if(argc == 19) + { + n_warmup = std::stoi(argv[16]); + n_iter = std::stoi(argv[17]); + rotating = std::stoull(argv[18]) * 1024 * 1024; + } + + using F32 = float; + using F16 = ck::half_t; + // using BF16 = ck::bhalf_t; + // using F8 = ck::f8_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_universal_streamk_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? 
DefaultStrideA : StrideA, + (StrideB < 0) ?
DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC, + Streamk_sel, + Grid_size, + n_warmup, + n_iter, + rotating); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_universal_streamk); diff --git a/script/check_copyright_year.sh b/script/check_copyright_year.sh old mode 100755 new mode 100644 -- GitLab From eb44e0472a2d58b623007b056173817793c9df3d Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Mon, 8 Jul 2024 13:55:54 -0400 Subject: [PATCH 78/96] Add ckProfiler support for forward 3D convolutions with OUT element-wise operations. (#1354) --- .../ck/library/utility/host_tensor.hpp | 10 +- ...ile_grouped_conv_fwd_outelementop_impl.hpp | 352 ++++++++++++++++++ profiler/src/CMakeLists.txt | 3 + .../profile_grouped_conv_fwd_outelementop.cpp | 220 +++++++++++ .../profile_grouped_conv_fwd_outelementop.sh | 20 + 5 files changed, 604 insertions(+), 1 deletion(-) create mode 100644 profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp create mode 100644 profiler/src/profile_grouped_conv_fwd_outelementop.cpp create mode 100755 script/profile_grouped_conv_fwd_outelementop.sh diff --git a/library/include/ck/library/utility/host_tensor.hpp b/library/include/ck/library/utility/host_tensor.hpp index ddbd16ad9..493b992ac 100644 --- a/library/include/ck/library/utility/host_tensor.hpp +++ b/library/include/ck/library/utility/host_tensor.hpp @@ -43,7 +43,15 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) first = false; else os << delim; - os << static_cast(v); + + if constexpr(std::is_same_v || std::is_same_v) + { + os << ck::type_convert(v); + } + else + { + os << static_cast(v); + } } return os; } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp new file mode 100644 index 000000000..bd756eb82 --- /dev/null +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -0,0 +1,352 @@ +#pragma once + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace profiler { + +template +inline constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else 
+ { + return 1e-3; + } +} + +template +inline constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + auto pass = true; // return status + + using CShuffleDataType = float; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using InElementOp = PassThrough; + using WeiElementOp = PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor host_output(out_g_n_k_wos_desc); + Tensor device_output(out_g_n_k_wos_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + 
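The f8/bf8 branches of get_rtol and get_atol above are chosen so that two neighbouring fp8-representable values still compare as equal: near 224 the e4m3 grid appears to be spaced 16 apart and near 49152 the e5m2 grid 8192 apart, which is what the "240 and 224" and "57344 and 49152" comments seem to refer to. A small standalone check of those pairs against the stated tolerances follows; the acceptance rule used here (absolute or relative error within bounds) is an assumed simplification of the usual check_err behaviour, not a quote of its implementation.

    // Sanity check of the fp8/bf8 tolerance choices in get_rtol/get_atol above.
    // The pass/fail rule below is an assumed simplification of check_err, not its code.
    #include <cmath>
    #include <cstdio>

    static bool within(double ref, double out, double rtol, double atol)
    {
        const double err = std::abs(ref - out);
        return err <= atol || err <= rtol * std::abs(ref);
    }

    int main()
    {
        // f8 (e4m3): 240 vs 224, spacing 16 -> passes with rtol 1e-1 / atol 16.1
        std::printf("f8 : %d\n", within(240.0, 224.0, 1e-1, 16.1));
        // bf8 (e5m2): 57344 vs 49152, spacing 8192 -> passes with rtol 1.5e-1 / atol 8192.1
        std::printf("bf8: %d\n", within(57344.0, 49152.0, 1.5e-1, 8192.1));
    }

Both checks print 1 with the tolerances defined above.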
wei_device_buf.ToDevice(weight.mData.data()); + + // random scale values + auto scale_in = type_convert( + type_convert(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX))); + auto scale_wei = type_convert( + type_convert(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX))); + auto scale_out = type_convert( + type_convert(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX))); + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + std::cout << "scale_in: " << scale_in << std::endl; + std::cout << "scale_wei: " << scale_wei << std::endl; + std::cout << "scale_out: " << scale_out << std::endl; + + // run reference op + if(do_verification) + { + + std::cout << "\nVerifying algorithm against reference convolution..." << std::endl; + std::cout << "\tUsing (rel_tol,abs_tol) = (" << std::setprecision(7) + << get_rtol() << ", " << get_atol() << ")" << std::endl; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd{}; + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + c.SetZero(); + ref_invoker.Run(ref_argument); + + host_output.ForEach([&](auto&, auto idx) { out_element_op(host_output(idx), c(idx)); }); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init output to zero before profiling next kernel + out_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = pass & ck::utils::check_err(device_output, + host_output, + "Error: Device and Host results do not match!", + get_rtol(), + get_atol()); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + }; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + AComputeType, + BComputeType>; + + // get device op instances + const auto op_ptrs = 
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl; + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {}, + {}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + run_impl(op_ptr, argument_ptr); + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index c2a976972..198f49432 100755 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -57,6 +57,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp) endif() @@ -134,6 +135,8 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance) endif() if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") diff --git a/profiler/src/profile_grouped_conv_fwd_outelementop.cpp b/profiler/src/profile_grouped_conv_fwd_outelementop.cpp new file mode 100644 index 000000000..196a2cf3f --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_outelementop.cpp @@ -0,0 +1,220 @@ +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK = 0, + NHWGC_GKYXC_NHWGK = 1 +}; + +enum struct OutElementOp +{ + ConvScale = 0, + ConvInvScale = 1 +}; + +enum struct ConvDataType +{ + F8_F8_F8 = 0, + BF8_BF8_F8 = 1, + F8_BF8_F8 = 2, + BF8_F8_F8 = 3 +}; + +#define OP_NAME "grouped_conv_fwd_outelementop" +#define OP_DESC "Grouped Convolution Forward+Elementwise Operation" + +static void print_helper_msg() +{ + // clang-format off + std::cout + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp8, Weight fp8, Output fp8\n" + << " 1: Input bf8, Weight bf8, Output fp8\n" + << " 2: Input fp8, Weight bf8, Output fp8\n" + << " 3: Input bf8, Weight fp8, Output fp8)\n" + << "arg3: element-wise operation (0: ConvScale\n" + << " 1: ConvInvScale)\n" + << "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" + 
<< "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_outelementop(int argc, char* argv[]) +{ + + // 9 total, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto op = static_cast(std::stoi(argv[3])); + const auto layout = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + 1 for argv[0] + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + using F8 = ck::f8_t; + using BF8 = ck::bf8_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using ConvScale = ck::tensor_operation::element_wise::ConvScale; + using ConvInvScale = ck::tensor_operation::element_wise::ConvInvscale; + + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto out_element_op, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using OutElementOp = decltype(out_element_op); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = ck::profiler::profile_grouped_conv_fwd_outelementop_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 
0 : 1; + }; + + if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(op == OutElementOp::ConvScale) + { + if(data_type == ConvDataType::F8_F8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, ConvScale{}, F8{}, F8{}); + } + else if(data_type == ConvDataType::BF8_BF8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF8{}, + BF8{}, + F8{}, + ConvScale{}, + BF8{}, + BF8{}); + } + else if(data_type == ConvDataType::F8_BF8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, ConvScale{}, F8{}, BF8{}); + } + else if(data_type == ConvDataType::BF8_F8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, ConvScale{}, BF8{}, F8{}); + } + } + else if(op == OutElementOp::ConvInvScale) + { + if(data_type == ConvDataType::F8_F8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, ConvInvScale{}, F8{}, F8{}); + } + else if(data_type == ConvDataType::BF8_BF8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF8{}, + BF8{}, + F8{}, + ConvInvScale{}, + BF8{}, + BF8{}); + } + else if(data_type == ConvDataType::F8_BF8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + F8{}, + BF8{}, + F8{}, + ConvInvScale{}, + F8{}, + BF8{}); + } + else if(data_type == ConvDataType::BF8_F8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF8{}, + F8{}, + F8{}, + ConvInvScale{}, + BF8{}, + F8{}); + } + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_outelementop); diff --git a/script/profile_grouped_conv_fwd_outelementop.sh b/script/profile_grouped_conv_fwd_outelementop.sh new file mode 100755 index 000000000..ac444a25c --- /dev/null +++ b/script/profile_grouped_conv_fwd_outelementop.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" + +OP=$1 +DATATYPE=$2 +OUTELEMENTOP=$3 +LAYOUT=$4 +VERIFY=$5 +INIT=$6 +LOG=$7 +TIME=$8 + +N=$9 + +####### op datatype OUTELEMENTOP layout verify init log time Ndims G N K C Z Y X Di Hi Wi Sz Sy Sx Dz Dy Dx Left Pz LeftPy LeftPx RightPz RightPy RightPx +$DRIVER $OP $DATATYPE $OUTELEMENTOP $LAYOUT $VERIFY $INIT $LOG $TIME 3 32 $N 96 96 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1 +$DRIVER $OP $DATATYPE $OUTELEMENTOP $LAYOUT $VERIFY $INIT $LOG $TIME 3 32 $N 192 192 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1 -- GitLab From 8182976c37433808b5e3a27a6536d1b74b0c23a1 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 9 Jul 2024 02:09:55 +0800 Subject: [PATCH 79/96] [CK_TILE] wa prec, remove sgpr offset for inline asm (#1356) * wa prec, remove sgpr offset for inline asm * macro for set tile * ignore unused param if no kernel instances in host API * fix more prec issue * cache buffer resource * fix * support pre-nop * clear tile by vector type members * add workaround to reduce scratch memory * conditionally enable workaround code * enable workaround start from certain build version * fallback set_tile() implementation from certain build version * undo template argument changes * put dummy asm in load_raw() * fix comments, refactor s_nop inside buffer_load --------- Co-authored-by: PoYen, Chen --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 4 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 3 + .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 3 + .../core/arch/amd_buffer_addressing.hpp | 518 +++++++++++------- 
include/ck_tile/core/arch/arch.hpp | 8 +- include/ck_tile/core/config.hpp | 9 + include/ck_tile/core/tensor/buffer_view.hpp | 45 +- include/ck_tile/core/tensor/load_tile.hpp | 19 +- .../ck_tile/core/tensor/null_tile_window.hpp | 2 + include/ck_tile/core/tensor/tensor_view.hpp | 24 +- .../ck_tile/core/tensor/tile_elementwise.hpp | 56 +- include/ck_tile/core/tensor/tile_window.hpp | 100 +++- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 33 +- 13 files changed, 581 insertions(+), 243 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 0160915a5..0df115dc3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -271,7 +271,9 @@ class FmhaBwdApiPool: per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) # GEMM0: Q@K=S^T diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 1486671f6..137d3a2f7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -278,6 +278,9 @@ class FmhaFwdApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) @dataclass diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 419fbaaea..509394509 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) @dataclass diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 2cd8bb5f0..7f488d1b7 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -54,233 +54,318 @@ template<> struct buffer_load_trait<4 , thread_buffer> { using payloa } // namespace impl // TODO: glc/slc/... 
-template +template struct buffer_load; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type // (exp_vector_type(xxx)) -template <> -struct buffer_load<16> +template +struct buffer_load<16, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 16); using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; - asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<8> +template +struct buffer_load<8, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 8); using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; - asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<4> +template +struct buffer_load<4, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; - asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<2> +template +struct buffer_load<2, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf 
and convert manually using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; - asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<1> +template +struct buffer_load<1, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; - asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template +template struct buffer_load_if; -template <> -struct buffer_load_if<16> +template +struct buffer_load_if<16, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<8> +template +struct buffer_load_if<8, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec 
%6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<4> +template +struct buffer_load_if<4, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<2> +template +struct buffer_load_if<2, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<1> +template +struct buffer_load_if<1, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { 
static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; #pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" @@ -294,17 +379,16 @@ struct buffer_store<16> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 16); using mbuf_t = fp32x4_t; - asm volatile( - "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -315,17 +399,16 @@ struct buffer_store<8> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 8); using mbuf_t = fp32x2_t; - asm volatile( - "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -336,17 +419,16 @@ struct buffer_store<4> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); using mbuf_t = float; - asm volatile( - "buffer_store_dword %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -357,17 +439,16 @@ struct buffer_store<2> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 2); using mbuf_t = short; - asm volatile( - "buffer_store_short %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -378,17 +459,16 @@ struct buffer_store<1> 
CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); using mbuf_t = float; - asm volatile( - "buffer_store_byte %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -402,21 +482,20 @@ struct buffer_store_if<16> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 16); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = fp32x4_t; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -431,7 +510,7 @@ struct buffer_store_if<8> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { @@ -439,14 +518,13 @@ struct buffer_store_if<8> auto save_exec = __builtin_amdgcn_read_exec(); // TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch using mbuf_t = ext_vector_t; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -461,21 +539,20 @@ struct buffer_store_if<4> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dword %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -490,21 +567,20 @@ struct buffer_store_if<2> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 2); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = short; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_short %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_short %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -519,21 +595,20 @@ struct buffer_store_if<1> CK_TILE_DEVICE void 
operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_byte %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -901,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); -CK_TILE_DEVICE void async_buffer_load_dword(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0) +template +CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t /*soffset*/, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0, + bool_constant = {}) { - asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dword %1, %2, 0 offen offset:%3 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "n"(ioffset) + : "memory"); + else + asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "n"(ioffset) + : "memory"); } CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) @@ -1223,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe template + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, @@ -1237,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, using type = thread_buffer; if constexpr(oob_conditional_check) { - buffer_load_if{}( - dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + buffer_load_if{}(dst, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + 0, + flag, + bool_constant{}); } else { - buffer_load{}( - dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + buffer_load{}(dst, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + 0, + flag, + bool_constant{}); } } template + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool pre_nop = false> CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, - index_t src_immediate_addr_offset = 0) + index_t src_immediate_addr_offset = 0, + bool_constant = {}) { static_assert(sizeof(T) * N == 4, "wrong! 
not implemented vector size"); - async_buffer_load_dword(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset); + async_buffer_load_dword_v(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + 0, + bool_constant{}); } template + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, const T* p_src_wave, index_t src_thread_element_offset, index_t src_element_space_size, - index_t is_valid_element = 0) + index_t is_valid_element = 0, + bool_constant = {}) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - amd_buffer_load_raw_impl( - dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element); + amd_buffer_load_raw_impl( + dst, + src_wave_buffer_resource, + src_thread_addr_offset, + 0, + is_valid_element, + bool_constant{}); +} + +// This version support buffer resource as input arg +template +CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, + const int32x4_t src_wave_buffer_resource, + index_t src_thread_element_offset, + index_t is_valid_element = 0, + bool_constant = {}) +{ + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_buffer_load_raw_impl( + dst, + src_wave_buffer_resource, + src_thread_addr_offset, + 0, + is_valid_element, + bool_constant{}); } // unfortunately async copy can not make sure invalid data is zero inside LDS @@ -1931,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, // buffer_load OOB still working. template -CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, - const T* p_src_wave, - index_t src_thread_element_offset, - index_t src_element_space_size) + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool pre_nop = false> +CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size, + bool_constant = {}) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); @@ -1943,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); amd_async_buffer_load_impl( - smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0); + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{}); +} + +// This version support buffer resource as input arg +template +CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, + const int32x4_t src_wave_buffer_resource, + index_t src_thread_element_offset, + bool_constant = {}) +{ + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_async_buffer_load_impl( + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{}); } // buffer_store requires: diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 4a69f67ae..65a3a4e2f 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load() " ::); } -CK_TILE_DEVICE void s_nop() +CK_TILE_DEVICE void s_nop(index_t cnt = 0) { #if 1 - asm volatile("\ - s_nop 0 \n \ - " ::); + asm volatile("s_nop %0" : : "n"(cnt) 
:); #else - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(cnt); #endif } diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 83637e18e..fa28aa2be 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -21,6 +21,7 @@ #define __gfx12__ #endif +#include "hip/hip_version.h" #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -147,6 +148,14 @@ #define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #endif +#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE +#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091 +#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1 +#else +#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0 +#endif +#endif + #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index ffe8f7a4f..ed705c91e 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -69,6 +69,8 @@ struct buffer_view invalid_element_value_ = T{0}; CK_TILE_HOST_DEVICE constexpr buffer_view() - : p_data_{}, buffer_size_{}, invalid_element_value_{} + : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) - : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size, T invalid_element_value) - : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + : p_data_{p_data}, + buffer_size_{buffer_size}, + cached_buf_res_{0}, + invalid_element_value_{invalid_element_value} { } + // this is non constexpr intentially (will call some intrinsic internally) + // Must call for buffers that need *_raw load/store + CK_TILE_HOST_DEVICE void init_raw() + { + cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type)); + } + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() { return address_space_enum::global; @@ -333,12 +346,15 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - get_raw(remove_cvref_t& dst, index_t i, bool is_valid_element) const + CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t& dst, + index_t i, + bool is_valid_element, + bool_constant = {}) const { constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -349,18 +365,21 @@ struct buffer_view, t_per_x, Coherence, oob_conditional_check>( - dst, p_data_, i, buffer_size_, is_valid_element); + amd_buffer_load_raw, t_per_x, Coherence, oob_conditional_check, pre_nop>( + dst, cached_buf_res_, i, is_valid_element, bool_constant{}); } // i is offset of T, not X. 
i should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - async_get(remove_cvref_t* smem, index_t i, bool /*is_valid_element*/) const + CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t* smem, + index_t i, + bool /*is_valid_element*/, + bool_constant = {}) const { // X is vector of T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -371,8 +390,8 @@ struct buffer_view, t_per_x, Coherence>( - smem, p_data_, i, buffer_size_); + amd_async_buffer_load_with_oob_raw, t_per_x, Coherence>( + smem, cached_buf_res_, i, bool_constant{}); } // i is offset of T, not X. i should be aligned to X @@ -627,6 +646,8 @@ struct buffer_view + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE auto load_tile_raw(T& tile, const tile_window_with_static_distribution& tile_window, - bool_constant = {}) + bool_constant = {}, + bool_constant = {}) { - tile_window.load_raw(tile, bool_constant{}); + tile_window.load_raw(tile, bool_constant{}, bool_constant{}); } template + index_t NumCoord, + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, const tile_window_with_static_distribution& tile_window) + NumCoord>& tile_window, + bool_constant = {}, + bool_constant = {}) { - return tile_window.async_load(lds_tile); + return tile_window.async_load_raw( + lds_tile, bool_constant{}, bool_constant{}); } CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp index 89806203a..9707f2990 100644 --- a/include/ck_tile/core/tensor/null_tile_window.hpp +++ b/include/ck_tile/core/tensor/null_tile_window.hpp @@ -35,6 +35,8 @@ struct null_tile_window CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; } + CK_TILE_DEVICE void init_raw() {} + WindowLengths window_lengths_; }; diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 656309532..4655eec24 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -36,6 +36,8 @@ struct tensor_view { } + CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); } + CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension() @@ -85,30 +87,34 @@ struct tensor_view // "coord" is coordinate of DataType, not X. 
"coord" should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE void - get_vectorized_elements_raw(remove_cvref_t& dst, - const TensorCoord& coord, - bool_constant = {}) const + CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t& dst, + const TensorCoord& coord, + bool_constant = {}, + bool_constant = {}) const { - return buf_.template get_raw( + return buf_.template get_raw( dst, coord.get_offset(), - coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); } template >::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t* smem, - const TensorCoord& coord) const + CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw( + remove_cvref_t* smem, const TensorCoord& coord, bool_constant = {}) const { - return buf_.template async_get(smem, coord.get_offset(), true /*not used*/); + return buf_.template async_get_raw( + smem, coord.get_offset(), true /*not used*/, bool_constant{}); } // X is vector of DataType. diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 5fecd19dc..79018b9ce 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&) // TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with // sub-dword tensor... -template -CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number) +template +CK_TILE_DEVICE void +set_tile(DstrTensors& dstr_tensor, number, bool_constant = {}) { - constexpr index_t tensor_bytes = - DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType); - if constexpr(v == 0 && tensor_bytes % 4 == 0) + using elem_type = typename DstrTensors::DataType; + constexpr index_t elem_size = sizeof(elem_type); + + constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size; + + // # bytes per write = 4 + if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt) { +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE + auto& buffer = dstr_tensor.get_thread_buffer(); + + static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) { + if constexpr(elem_size == 1) + { + // # elements per write = 4 + constexpr auto values = ext_vector_t{0, 0, 0, 0}; + + buffer[i_write * 4 + 0] = values.x; + buffer[i_write * 4 + 1] = values.y; + buffer[i_write * 4 + 2] = values.z; + buffer[i_write * 4 + 3] = values.w; + } + else if constexpr(elem_size == 2) + { + // # elements per write = 2 + constexpr auto values = ext_vector_t{0, 0}; + + buffer[i_write * 2 + 0] = values.x; + buffer[i_write * 2 + 1] = values.y; + } + else if constexpr(elem_size == 4) + { + // # elements per write = 1 + constexpr elem_type value = 0; + + buffer[i_write] = value; + } + else + { + static_assert(false, "type not supported"); + } + }); +#else using dvec_t = array; auto& tensor = reinterpret_cast(dstr_tensor.get_thread_buffer()); for(auto i = 0; i < tensor.size(); i++) tensor.get(i) = v; +#endif } else { - tile_elementwise_inout( - [](auto& x) { x = type_convert(v); }, - dstr_tensor); + tile_elementwise_inout([](auto& x) { x = type_convert(v); }, + dstr_tensor); } } diff --git a/include/ck_tile/core/tensor/tile_window.hpp 
b/include/ck_tile/core/tensor/tile_window.hpp index 2c38c6aa2..70f381db1 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -344,9 +344,10 @@ struct tile_window_with_static_distribution return dst_tensor; } - template + template CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, - bool_constant = {}) const + bool_constant = {}, + bool_constant = {}) const { using Traits = load_store_traits; @@ -373,7 +374,13 @@ struct tile_window_with_static_distribution auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); @@ -384,7 +391,8 @@ struct tile_window_with_static_distribution get_bottom_tensor_view().template get_vectorized_elements_raw( dst_vec_tbuf.template at(), bottom_tensor_thread_coord, - bool_constant{}); + bool_constant{}, + pre_nop_); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -399,12 +407,17 @@ struct tile_window_with_static_distribution } }); }); +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE + asm volatile("; this inline asm is workaround to prevent compiler from using too much " + "scratch memory" ::); +#endif } // TODO: currently async load only implemented in inline asm - template - CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, - bool_constant = {}) const + template + CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + bool_constant = {}, + bool_constant = {}) const { using LdsTileWindow = remove_cvref_t; // using LdsTensorView = typename LdsTileWindow::BottomTensorView; @@ -449,11 +462,17 @@ struct tile_window_with_static_distribution auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements( - smem, bottom_tensor_thread_coord); + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, bottom_tensor_thread_coord, pre_nop_); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -668,6 +687,67 @@ struct tile_window_with_static_distribution }); } + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + { + window_origin_ = new_window_origin; + +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: 
this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_dstr_), array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } + // this is the bottom tensor view // [x0', x1', ...] ==> [offset] BottomTensorView bottom_tensor_view_; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index e9a14ca5a..8251627e6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -81,6 +81,12 @@ struct BlockFmhaPipelineQRKSVSAsync return Problem::kBlockPerCu; else { + // minimize occupancy + if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) + { + return 1; + } + if constexpr(kK0BlockLength <= 32) { if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && @@ -220,6 +226,7 @@ struct BlockFmhaPipelineQRKSVSAsync q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQDramTileDistribution()); + q_dram_window.init_raw(); // TODO: we use async Copy for K, which is inline asm // a side effect is we have to use inline asm for q as well @@ -293,6 +300,17 @@ struct BlockFmhaPipelineQRKSVSAsync k_dram_block_window.get_window_origin(), Policy::template MakeKDramTileDistribution()); // K DRAM tile window for // load + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = [&]() { + if constexpr(kPadSeqLenK && + (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout))) + return bool_constant{}; + else + return bool_constant{}; + }(); + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window( bias_dram_block_window_tmp.get_bottom_tensor_view(), @@ -310,7 +328,7 @@ struct BlockFmhaPipelineQRKSVSAsync Policy::template MakeVDramTileDistribution()); // prefetch K tile - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, 
k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); @@ -333,7 +351,9 @@ struct BlockFmhaPipelineQRKSVSAsync { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { async_load_tile_raw(k_lds_store(number{})>{}), - k_dram_window); + k_dram_window, + k_oob_ck, + k_pre_np); if constexpr(i_k0 < k0_loops - 1) move_tile_window(k_dram_window, {0, kK0}); @@ -637,16 +657,13 @@ struct BlockFmhaPipelineQRKSVSAsync { // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window = - make_tile_window(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), - Policy::template MakeKDramTileDistribution()); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) __builtin_amdgcn_s_barrier(); - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); } // tail -- GitLab From a328df25a131b5c1d30cbeadd4255ff39f19f977 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 8 Jul 2024 21:21:16 -0700 Subject: [PATCH 80/96] Fix the cmake logic when building with INSTANCES_ONLY=ON. (#1376) * fix the cmake logic when building for various targets * another minor fix --- example/CMakeLists.txt | 4 ++-- library/src/tensor_operation_instance/gpu/CMakeLists.txt | 8 ++++---- test/CMakeLists.txt | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 87c5a89f8..45cfee4de 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() @@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 1bcc0f802..2081422e3 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -59,7 +59,7 @@ function(add_instance_library INSTANCE_NAME) endforeach() # Do not build WMMA instances if gfx11 targets are not on the target list foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma instance ${source} ") list(REMOVE_ITEM ARGN 
"${source}") endif() @@ -177,7 +177,7 @@ FOREACH(subdir_path ${dir_list}) message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12")) + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") set(add_inst 0) endif() @@ -185,11 +185,11 @@ FOREACH(subdir_path ${dir_list}) message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9")) + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9")) message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) + if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.") set(add_inst 0) endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3b121fc30..fc1bcfdb2 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -60,7 +60,7 @@ function(add_test_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") + if(NOT TEST_TARGETS MATCHES "gfx11" AND NOT TEST_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -141,7 +141,7 @@ function(add_gtest_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") + if(NOT TEST_TARGETS MATCHES "gfx11" AND NOT TEST_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() -- GitLab From ccfdc5302238198e0e0e5c0c3f05f41b79fcebb8 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 9 Jul 2024 20:30:07 +0800 Subject: [PATCH 81/96] update owner (#1377) * remove zjing14, add poyenc * remove yigex --- .github/CODEOWNERS | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index bc49ac166..de17acb9c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex +* @junliume @illsilin @carlushuang @aosewski @poyenc # Documentation files -docs/* @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex -*.md @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex -*.rst @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang 
@aosewski @yigex -.readthedocs.yaml @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex +docs/* @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc +*.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc +*.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc +.readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc # Header directory for Doxygen documentation -library/include/* @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex +library/include/* @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc -- GitLab From da42a889645c03d80e61423531fecfdc188c2ab9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 Jul 2024 12:48:23 -0700 Subject: [PATCH 82/96] Bump rocm-docs-core from 1.4.1 to 1.5.0 in /docs/sphinx (#1374) Bumps [rocm-docs-core](https://github.com/ROCm/rocm-docs-core) from 1.4.1 to 1.5.0. - [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.4.1...v1.5.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Sam Wu <22262939+samjwu@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 6605380a5..51bfef289 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.4.1 +rocm-docs-core==1.5.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index a3566090e..6d2fe6ca5 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==1.4.1 +rocm-docs-core==1.5.0 # via -r requirements.in six==1.16.0 # via -- GitLab From 860f957c22b99b9b6add0b1b9132fd78079f3d10 Mon Sep 17 00:00:00 2001 From: Sam Wu <22262939+samjwu@users.noreply.github.com> Date: Wed, 10 Jul 2024 09:36:10 -0600 Subject: [PATCH 83/96] Update changelog release headers (#1378) * Update doc codeowner syntax * Add doc link to changelog * Update changelog formatting for markdownlint Also change headings for releases --- .github/CODEOWNERS | 4 ++-- CHANGELOG.md | 27 ++++++++++++++++++++------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index de17acb9c..1809abebb 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ * @junliume @illsilin @carlushuang @aosewski @poyenc # Documentation files -docs/* @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc +docs/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc *.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc *.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc .readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc # Header directory for Doxygen documentation -library/include/* 
@ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc +library/include/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc diff --git a/CHANGELOG.md b/CHANGELOG.md index fb2ba1975..dec6334cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,38 +1,46 @@ # Changelog for Composable Kernel -Full documentation for Composable Kernel is not yet available. +Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). -## CK for ROCm 6.1.0 +## Composable Kernel 1.1.0 for ROCm 6.1.0 ### Additions + * Added generic instances for GEMM XDL operations (#1161) * Added gamma and beta parameters for the layernorm and groupnorm bwd operations (#1133) * Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126) * Added an option to vary the number of warm-up cycles and iterations for ckProfiler (#1124) ### Optimizations + * New performance optimizations for GEMM operations on MI200 and MI300 architectures (#1135) ### Fixes + * Reduced the build time for most GPU architectures (#1084) * Fixed some conversion issues for fp8 data type (#1099) ### Changes + None ### Known issues + None -## CK for ROCm 6.0.0 +## Composable Kernel 1.1.0 for ROCm 6.0.0 ### Fixes - * Fixed a hazard associated with inline v_dot (#808) - * Fixed two bugs in grouped convolution backward data without K padding (#848 #876) + +* Fixed a hazard associated with inline v_dot (#808) +* Fixed two bugs in grouped convolution backward data without K padding (#848 #876) ### Optimizations + None ### Additions + * Added an image to a column kernel (#867) * Added a column to an image kernel (#930) * Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985) @@ -42,18 +50,22 @@ None * Support for Batched GEMM DL (#732) ### Changes - * Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) -## CK 0.2.0 for ROCm 5.7.0 +* Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) + +## Composable Kernel 0.2.0 for ROCm 5.7.0 ### Fixes + * Fixed a bug in 6-dimensional kernels (#555) * Fixed a test case failure with grouped convolution backward weight (#524) ### Optimizations + * Improved the performance of the normalization kernel ### Additions + * New CMake flags: * "DL_KERNELS"-* Must be set to "ON" in order to build the GEMM DL and batched_gemm_multi_d_dl instances * "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types @@ -71,4 +83,5 @@ None * MaxPool and AvgPool forward (#815); MaxPool backward (#750) ### Changes + None -- GitLab From a8eb872055f1f741fad8033611ca6c8aacdd10a8 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 10 Jul 2024 14:54:04 -0700 Subject: [PATCH 84/96] [gfx12] add gfx12 to the default target list (#1379) --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b27e6ab4f..fc0cc4ddb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -112,7 +112,7 @@ message("checking which targets are supported") #Setting GPU_TARGETS on command line will override this list if(NOT PROFILER_ONLY) rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") + TARGETS 
"gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") else() add_definitions(-DPROFILER_ONLY) set(GPU_TARGETS "" CACHE STRING "" FORCE) @@ -148,7 +148,7 @@ if (GPU_TARGETS) add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") endif() - if (GPU_TARGETS MATCHES "gfx11") + if (GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() -- GitLab From f914c228c697351112c1d25499123e51cbf1b5e9 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 11 Jul 2024 10:28:11 -0700 Subject: [PATCH 85/96] [Jenkins] restore cron jobs (#1380) * test the cron trigger * fix the cron jobs * restore the list of cron jobs --- Jenkinsfile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Jenkinsfile b/Jenkinsfile index 8809fc50c..f65ddab5c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -661,6 +661,9 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM pipeline { agent none + triggers { + parameterizedCron(CRON_SETTINGS) + } options { parallelsAlwaysFailFast() } -- GitLab From 98a01bbc72fabd4a15d1472357104913a35d619e Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 11 Jul 2024 13:22:40 -0700 Subject: [PATCH 86/96] Add CK_TILE tests to daily CI builds. (#1381) * add ck_tile tests to CI * build and run ck_tile tests on gfx90a and gfx942 in parallel * fix groovy syntax * turn ck_tile tests OFF by default * skip creating the build folder * build ck_tile examples with 64 threads * build ck_tile examples with cmake-ck-dev.sh script * add video group to docker on mi300 * do not retry to rebuild the early CI stages * help prevent jenkins false failure * restore cron trigger --- Jenkinsfile | 64 ++++++++++++++++++- .../ck_tile/01_fmha/script/benchmark_bwd.sh | 0 .../ck_tile/01_fmha/script/smoke_test_bwd.sh | 3 +- .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 4 +- 4 files changed, 66 insertions(+), 5 deletions(-) mode change 100644 => 100755 example/ck_tile/01_fmha/script/benchmark_bwd.sh mode change 100644 => 100755 example/ck_tile/01_fmha/script/smoke_test_bwd.sh diff --git a/Jenkinsfile b/Jenkinsfile index f65ddab5c..e9d55992d 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -315,6 +315,10 @@ def buildHipClangJob(Map conf=[:]){ if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') + dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${dockerOpts}" def variant = env.STAGE_NAME @@ -366,6 +370,11 @@ def runCKProfiler(Map conf=[:]){ if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') + dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${dockerOpts}" + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " def variant = 
env.STAGE_NAME @@ -653,7 +662,7 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.1; +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.1; RUN_CK_TILE_TESTS=true 0 21 * * * % ROCMVERSION=6.1;hipTensor_test=true 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false @@ -724,6 +733,10 @@ pipeline { name: "RUN_CODEGEN_TESTS", defaultValue: true, description: "Run the codegen tests (default: ON)") + booleanParam( + name: "RUN_CK_TILE_TESTS", + defaultValue: false, + description: "Run the ck_tile tests (default: OFF)") booleanParam( name: "BUILD_INSTANCES_ONLY", defaultValue: false, @@ -816,7 +829,6 @@ pipeline { beforeAgent true expression { params.RUN_CODEGEN_TESTS.toBoolean() } } - options { retry(2) } agent{ label rocmnode("gfx90a")} environment{ setup_args = "NO_CK_BUILD" @@ -833,6 +845,52 @@ pipeline { } } } + } + stage("Run CK_TILE Tests") + { + parallel + { + stage("Run CK_TILE Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ + cd ../ && + example/ck_tile/01_fmha/script/smoke_test_fwd.sh && \ + example/ck_tile/01_fmha/script/smoke_test_bwd.sh""" + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Run CK_TILE Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ + make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ + cd ../ && + example/ck_tile/01_fmha/script/smoke_test_fwd.sh && \ + example/ck_tile/01_fmha/script/smoke_test_bwd.sh""" + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } } stage("Build CK and run Tests") { @@ -973,7 +1031,7 @@ pipeline { beforeAgent true expression { params.RUN_PERFORMANCE_TESTS.toBoolean() } } - options { retry(2) } + options { retry(1) } agent{ label rocmnode("gfx90a")} environment{ setup_args = """ -DGPU_TARGETS="gfx90a" -DBUILD_DEV=On """ diff --git a/example/ck_tile/01_fmha/script/benchmark_bwd.sh b/example/ck_tile/01_fmha/script/benchmark_bwd.sh old mode 100644 new mode 100755 diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh old mode 100644 new mode 100755 index 9fe795471..d6830aa2e --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -8,7 +8,7 @@ export CK_WARMUP=0 export CK_REPEAT=1 COMMON_ARGS='-v=1' - +set -x for prec in "fp16" "bf16" ; do for perm in 0 1 ; do for hdim in 32 64 128 ; do @@ -31,3 +31,4 @@ done done done done +set +x diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 
63813e079..779e8d09e 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -10,7 +10,7 @@ export CK_REPEAT=1 COMMON_ARGS='-v=1 -warmup=0 -repeat=1' # mode=0 # export HIP_VISIBLE_DEVICES=4 - +set -x for prec in "fp16" "bf16" ; do for mode in 1 0 ; do for perm in 0 1 ; do @@ -40,6 +40,7 @@ done done done + for perm in 0 1 ; do for bias in "n" "e" "a" ; do for b in 1 2 ; do @@ -49,3 +50,4 @@ done done done done +set +x -- GitLab From 7a46a91c840be960e14c480341eb3ad8f4d08ab7 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Thu, 11 Jul 2024 15:31:39 -0500 Subject: [PATCH 87/96] Add instances for grouped conv fwd 3d with ConvScale for bf8@fp8->fp8 (#1369) * Add an example * Add instances * Add a client example --- .../24_grouped_conv_activation/CMakeLists.txt | 4 + .../conv3d_fwd_convscale_bf8_fp8.cpp | 50 +++++++++++ .../62_convnd_activ/convscale/CMakeLists.txt | 3 + .../convnd_fwd_xdl_convscale_bf8_fp8.cpp | 88 +++++++++++++++++++ ...ped_conv_fwd_xdl_outelementop_instance.hpp | 37 ++++++++ .../grouped_convolution_forward_convscale.hpp | 24 +++++ .../CMakeLists.txt | 3 +- ...e_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp | 62 +++++++++++++ 8 files changed, 270 insertions(+), 1 deletion(-) create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp create mode 100644 example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index a624302db..77e54f1f1 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -51,6 +51,10 @@ target_link_libraries(client_conv3d_fwd_convscale_bf8 PRIVATE composable_kernel: add_executable(client_conv3d_fwd_convscale_fp8_bf8 grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp) target_link_libraries(client_conv3d_fwd_convscale_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_conv3d_fwd_convscale_bf8_fp8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) # Bwd data bilinear add_executable(client_grouped_convnd_bwd_data_bilinear_residual_fp16 grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp) diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp new file mode 100644 index 000000000..192c4fdcb --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
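+//
+// Client example for the new ConvScale bf8 @ f8 -> f8 instances: runs a 3-D grouped
+// forward convolution (NDHWGC input, GKZYXC weights, NDHWGK output) with bf8
+// activations and f8 weights, producing an f8 output through the ConvScale output
+// elementwise op.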
+ +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::bf8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/62_convnd_activ/convscale/CMakeLists.txt b/example/62_convnd_activ/convscale/CMakeLists.txt index 3de1aff67..9264da24a 100644 --- a/example/62_convnd_activ/convscale/CMakeLists.txt +++ b/example/62_convnd_activ/convscale/CMakeLists.txt @@ -12,6 +12,9 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_convnd_fwd_xdl_convscale_fp8_bf8 convnd_fwd_xdl_convscale_fp8_bf8.cpp) add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8_bf8) + add_example_executable(example_convnd_fwd_xdl_convscale_bf8_fp8 convnd_fwd_xdl_convscale_bf8_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8_fp8) + set(target 1) endif() endforeach() diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp new file mode 100644 index 000000000..8590d0620 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
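+//
+// Example: instantiates DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle with bf8
+// activations, f8 weights and f8 output (fp32 accumulation), PassThrough input/weight
+// elementwise ops and a ConvScale output elementwise op, using the default convolution
+// specialization with MNKPadding.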
+ +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::bf8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 
0 : 1; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp index 0576873b8..e3bec1751 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp @@ -147,6 +147,43 @@ using device_grouped_conv_fwd_xdl_outelementop_f8_bf8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8) + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> 
+#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp index 9de072369..63dcdc605 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp @@ -70,6 +70,22 @@ void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_ins ConvScale, F8, BF8>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( + std::vector, + NDHWGK, + BF8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + BF8, + F8>>>& instances); #endif template && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( + op_ptrs); + } #endif } return op_ptrs; diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt index b4cfd1a23..c7f4a3527 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt @@ -2,6 +2,7 @@ set(GROUPED_CONV3D_FWD_CONVSCALE xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp) + xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp) add_instance_library(device_grouped_conv3d_fwd_convscale_instance ${GROUPED_CONV3D_FWD_CONVSCALE}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp new file mode 100644 index 000000000..8e2c0eb1b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
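+
+// Overview: this new translation unit registers the grouped conv3d forward
+// "ConvScale" instances for the BF8/F8 data-type combination (NDHWGC input,
+// GKZYXC weight, NDHWGK output). It reuses the shared
+// device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances list for the
+// ConvFwdDefault, ConvFwd1x1P0 and ConvFwd1x1S1P0 specializations.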
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( + std::vector, + NDHWGK, + BF8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + BF8, + F8>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvScale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -- GitLab From 13c1e64daa47061f7b95ba7fec3d1d2b605191e4 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Thu, 11 Jul 2024 20:08:07 -0500 Subject: [PATCH 88/96] add gemm_bias_add example (#1361) * add gemm_bias_add example * changed strideD * clang-format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .../contraction_multi_ABD_xdl_fp8.cpp | 314 ++++++++++++++++++ .../65_gemm_multiply_multiply/CMakeLists.txt | 3 +- .../gemm_add_add_xdl_fp16.cpp | 270 +++++++++++++++ ...cpp => gemm_multiply_multiply_xdl_fp8.cpp} | 0 4 files changed, 586 insertions(+), 1 deletion(-) create mode 100644 example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp create mode 100644 example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp rename example/65_gemm_multiply_multiply/{gemm_multiply_multiply_xdl_fp16.cpp => gemm_multiply_multiply_xdl_fp8.cpp} (100%) diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp new file mode 100644 index 000000000..eaabccdf2 --- /dev/null +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
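+
+// Overview: this example computes a 2+2+2-dimensional tensor contraction,
+// roughly
+//   E[m0,m1,n0,n1] = sum_{k0,k1} A[m0,m1,k0,k1] * B[n0,n1,k0,k1]
+// where A and B are formed on the fly by the Multiply functor defined below:
+//   A = convert_to_f8(A0 * A1),  B = convert_to_f8(B0 * B1)
+// with A0/B0 in f8 and A1/B1 f32 scale tensors (A1 uses zero strides in two
+// dimensions, so it acts as a broadcast scale). Accumulation is in f32 and
+// the result E is written out in f16.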
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using A0DataType = F8; +using A1DataType = F32; +using B0DataType = F8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using EDataType = F16; +using ComputeDataType = F8; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct Multiply +{ + __host__ __device__ constexpr void + operator()(ck::f8_t& a, const ck::f8_t& a0, const float& a1) const + { + a = ck::type_convert(ck::type_convert(a0) * a1); + } +}; + +using AElementOp = Multiply; +using BElementOp = Multiply; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultipleABD_Xdl_CShuffle< + NumDimM, + NumDimN, + NumDimK, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 1, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A0[M0, M1, K0, K1] + std::vector a0_ms_ks_lengths{30, 128, 32, 64}; + std::vector a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1}; + // A1[M1, K1] -> A1[M0, M1, K0, K1] + std::vector a1_ms_ks_lengths{30, 128, 32, 64}; + std::vector a1_ms_ks_strides{0, 64, 1, 0}; + // B0[N0, N1, K0, K1] + std::vector b0_ns_ks_lengths{32, 64, 32, 64}; + std::vector b0_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; + // B1[N0, N1, K0, K1] + std::vector b1_ns_ks_lengths{32, 64, 32, 64}; + std::vector b1_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); + Tensor b0_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); + Tensor b1_ns_ks(b1_ns_ks_lengths, b1_ns_ks_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor 
e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; + std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; + + std::cout << "b0_ns_ks: " << b0_ns_ks.mDesc << std::endl; + std::cout << "b1_ns_ks: " << b1_ns_ks.mDesc << std::endl; + + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_ms_ks.mData.data()); + a1_device_buf.ToDevice(a1_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{}, + e_device_buf.GetDeviceBuffer(), + std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, + std::array, 2>{a0_ms_ks_strides, a1_ms_ks_strides}, + std::array, 2>{b0_ns_ks_lengths, b1_ns_ks_lengths}, + std::array, 2>{b0_ns_ks_strides, b1_ns_ks_strides}, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + PassThrough{}); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_contraction with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + { + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a0_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + if(do_verification) + { + + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + + for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < a_ms_ks.mDesc.GetLengths()[1]; ++m1) + { + for(size_t k0 = 0; k0 < a_ms_ks.mDesc.GetLengths()[2]; ++k0) + { + for(size_t k1 = 0; k1 < a_ms_ks.mDesc.GetLengths()[3]; ++k1) + { + a_element_op(a_ms_ks(m0, m1, k0, k1), + a0_ms_ks(m0, m1, k0, k1), + a1_ms_ks(m0, m1, k0, k1)); + } + } + } + } + + Tensor b_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); + + for(size_t n0 = 0; n0 < b_ns_ks.mDesc.GetLengths()[0]; ++n0) + { + for(size_t n1 = 0; n1 < b_ns_ks.mDesc.GetLengths()[1]; ++n1) + { + for(size_t k0 = 0; k0 < b_ns_ks.mDesc.GetLengths()[2]; ++k0) + { + for(size_t k1 = 0; k1 < b_ns_ks.mDesc.GetLengths()[3]; ++k1) + { + b_element_op(b_ns_ks(n0, n1, k0, k1), + b0_ns_ks(n0, n1, k0, k1), + b1_ns_ks(n0, n1, k0, k1)); + } + } + } + } + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = ref_op.MakeArgument( + a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index f3594d153..d968bdb9d 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -1 +1,2 @@ -add_example_executable(example_gemm_multiply_multiply_xdl_fp16 gemm_multiply_multiply_xdl_fp16.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) +add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) diff --git a/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp new file mode 100644 index 000000000..5fea43ffc --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
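+
+// Overview: this example fuses two residual additions into the GEMM epilogue
+// via the AddAdd functor defined below, i.e. (sketch)
+//   E[m,n] = f16( sum_k A[m,k] * B[k,n] + D0[m,n] + D1[m,n] )
+// A and B are f16 (row- / column-major), D0, D1 and the accumulator are f32,
+// and E is written back as f16.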
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct AddAdd +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c + d0 + d1; + + e = ck::type_convert(x0_f); + } +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = K; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + 
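+
+    // Note: the device op consumes the two D tensors as fixed-size arrays
+    // (one array of pointers, one of strides); both use StrideD in this example.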
e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD, StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 
0 : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp similarity index 100% rename from example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp rename to example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp -- GitLab From 82e8a78a3f5ed8906162bf48d22fdf525c99aa12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 12 Jul 2024 20:08:42 +0200 Subject: [PATCH 89/96] Support access per groups and filter3x3 in grouped conv fwd (#1382) * Support access per groups and filter3x3 in grouped conv fwd * Fixes for large cases * Fixes for large tensors --- .../convolution_forward_specialization.hpp | 4 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 76 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 127 +- .../device/impl/device_grouped_conv_utils.hpp | 16 + .../transform_conv_bwd_weight_to_gemm_v2.hpp | 90 +- .../transform_conv_fwd_to_gemm.hpp | 1017 +++++++++++++---- ...conv_bwd_weight_two_stage_xdl_instance.hpp | 8 +- ...ed_conv_fwd_xdl_merged_groups_instance.hpp | 96 ++ .../gpu/grouped_convolution_forward.hpp | 13 + ..._convolution_forward_xdl_merged_groups.inc | 112 ++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 5 + ...groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 48 + ..._groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 48 + ..._groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 48 + .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 4 + ...ups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 47 + ...oups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 47 + ...oups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 47 + .../test_grouped_convnd_fwd.cpp | 10 +- 19 files changed, 1499 insertions(+), 364 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index adfa1689c..0eef827a5 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// 
Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -15,6 +15,7 @@ enum struct ConvolutionForwardSpecialization Filter1x1Pad0, Filter1x1Stride1Pad0, OddC, + Filter3x3, }; inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s) @@ -25,6 +26,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; case ConvolutionForwardSpecialization::OddC: return "OddC"; + case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3"; default: return "Unrecognized specialization!"; } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index d9e300b73..e18b8b9e2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -36,7 +36,7 @@ template struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle @@ -238,7 +238,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle NPerBlock, K1Number, KPerBlock / K1Number, - NumBatchToMerge, + NumGroupsToMerge, ConvBackwardWeightSpecialization>{}; static constexpr auto conv_to_gemm_transformer_v1 = @@ -638,7 +638,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle index_t gdx, gdy, gdz; std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( - gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumBatchToMerge); + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge); float ave_time = 0; @@ -724,7 +724,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -739,7 +739,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -760,7 +760,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -777,7 +777,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -796,7 +796,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -817,7 +817,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, 
ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -838,7 +838,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -859,7 +859,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -879,7 +879,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -900,7 +900,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -920,7 +920,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -937,7 +937,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -956,7 +956,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -977,7 +977,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -998,7 +998,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1019,7 +1019,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1039,7 +1039,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1060,7 +1060,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1084,7 +1084,7 @@ struct 
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -1100,7 +1100,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -1119,7 +1119,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1135,7 +1135,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1157,7 +1157,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -1173,7 +1173,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, @@ -1192,7 +1192,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1208,7 +1208,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1232,7 +1232,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, false, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>; @@ -1247,7 +1247,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle remove_reference_t< DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, ComputePtrOffsetOfStridedBatch, - NumBatchToMerge, + NumGroupsToMerge, false, InMemoryDataOperationEnum::Set, minimum_occupancy>; @@ -1389,7 +1389,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } } - if constexpr(NumBatchToMerge > 1) + if constexpr(NumGroupsToMerge > 1) { // support only if whole M and N can be proccessed on one block if(!(GemmM <= MPerBlock && GemmN <= NPerBlock)) @@ -1400,7 +1400,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { return false; } - if(arg.Conv_G_ % NumBatchToMerge != 0) + if(arg.Conv_G_ % NumGroupsToMerge != 0) { return false; } @@ -1563,7 +1563,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " - << NumBatchToMerge + << NumGroupsToMerge 
<< ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index f5a8d4e9f..2ee17c5a0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -86,7 +86,6 @@ __global__ void const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, - const index_t groups_count, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -101,14 +100,11 @@ __global__ void defined(__gfx94__)) // offset base pointer for each work-group - const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); - const index_t& num_blocks_per_n = groups_count; - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); - - const long_index_t e_batch_offset = + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); - const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); @@ -121,14 +117,14 @@ __global__ void DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); if constexpr(isMultiA || isMultiB) { AsPointer p_as_grid_grp; BsPointer p_bs_grid_grp; - const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); + const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); // compute_ptr_offset_of_n_ not need BatchStrideB so // in case of MultiA is false but isMultiB is true @@ -139,27 +135,27 @@ __global__ void static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); static_for<0, NumATensor, 1>{}([&](auto i) { - p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i]; + p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i]; }); } else { const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); static_for<0, 1, 1>{}( - [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; }); + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; }); } - const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); + const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); static_for<0, NumBTensor, 1>{}( - [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; }); GridwiseGemm::template Run( p_as_grid_grp, p_bs_grid_grp, p_ds_grid_grp, - p_e_grid + e_batch_offset + 
e_n_offset, + p_e_grid + e_group_offset + e_n_offset, p_shared, a_element_op, b_element_op, @@ -172,19 +168,19 @@ __global__ void } else { - const long_index_t a_batch_offset = + const long_index_t a_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = + const long_index_t b_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); const long_index_t a_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); GridwiseGemm::template Run( - p_as_grid + a_batch_offset + a_n_offset, - p_bs_grid + b_batch_offset, + p_as_grid + a_group_offset + a_n_offset, + p_bs_grid + b_group_offset, p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, + p_e_grid + e_group_offset + e_n_offset, p_shared, a_element_op, b_element_op, @@ -200,7 +196,6 @@ __global__ void ignore = p_bs_grid; ignore = p_ds_grid; ignore = p_e_grid; - ignore = groups_count; ignore = a_grid_desc_k0_m_k1; ignore = b_grid_desc_k0_n_k1; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; @@ -287,7 +282,8 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + index_t NumGroupsToMerge = 1> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle : public DeviceGroupedConvFwdMultipleABD= 1); + static constexpr bool isMultiA = is_detected::value; static constexpr bool isMultiB = is_detected::value; @@ -319,7 +317,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static constexpr auto I3 = Number<3>{}; static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + TransformConvFwdToGemm{}; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; @@ -550,7 +548,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { static_for<0, NumATensor, 1>{}([&](auto i) { // Init compute_ptr_offset_of_groups_ for multiple AB - compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideA_(i) = + a_g_n_c_wis_strides[0] * NumGroupsToMerge; // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data // type is not tuple) @@ -578,7 +577,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle }); static_for<0, NumBTensor, 1>{}([&](auto i) { // Init compute_ptr_offset_of_groups_ for multiple AB - compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_(i) = + b_g_k_c_xs_strides[0] * NumGroupsToMerge; using DataType = remove_cvref_t>; // It is possible that one of the AB is a pointer and one is a tuple. 
@@ -598,8 +598,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } else { - compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideA_ = + a_g_n_c_wis_strides[0] * NumGroupsToMerge; + compute_ptr_offset_of_groups_.BatchStrideB_ = + b_g_k_c_xs_strides[0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; // p_as and p_bs are pointers @@ -616,7 +618,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_(i) = static_cast(p_ds[i]); // D batch stride - compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][1] * conv_N_per_block_; @@ -624,7 +627,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_); }); - compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; // populate desc for Ds/E @@ -745,8 +748,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); - const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N; - const index_t gdz = 1; + const index_t gdy = arg.num_group_ / NumGroupsToMerge; + const index_t gdz = num_workgroups_per_Conv_N; const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); @@ -795,7 +798,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_, - arg.a_g_n_c_wis_lengths_[0], // Group count as_grid_desc_ak0_m_ak1, bs_grid_desc_bk0_n_bk1, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, @@ -839,7 +841,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_, - arg.a_g_n_c_wis_lengths_[0], // Group count arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, @@ -871,6 +872,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; + const index_t G = arg.b_g_k_c_xs_lengths_[I0]; + const index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + // check device if(get_device_name() == "gfx908") { @@ -919,6 +924,42 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3) + { + if(C != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC()) + { + return false; + } + } + + if constexpr(NumGroupsToMerge > 1) + { + if(!(C == 1)) + { + return false; + } + if(G % NumGroupsToMerge != 0) + { + return false; + } + if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC()) + { + return false; + } + } // check vector access of A // FIXME: layout @@ -928,11 +969,16 @@ 
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v) { - const index_t C = arg.a_g_n_c_wis_lengths_[2]; - + // Check access per C if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) { - return false; + // If not possible, check access per G + if(!(ABlockTransferSrcVectorDim == 1 && C == 1 && + is_NSpatialGK_GKSpatial_NSpatialGC() && + G % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } } } else @@ -949,8 +995,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v) { - const index_t C = arg.b_g_k_c_xs_lengths_[2]; - if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) { return false; @@ -974,8 +1018,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v || is_same_v) { - const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; - if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) { valid = false; @@ -1020,8 +1062,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v) { - const index_t K = arg.e_g_n_k_wos_lengths_[2]; - if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) { return false; @@ -1172,7 +1212,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle << BBlockTransferSrcScalarPerVector << ", " << CDEBlockTransferScalarPerVector_NPerBlock << ", " << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle + << CShuffleNXdlPerWavePerShuffle << ", " + << NumGroupsToMerge << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp index c20e5d36f..3ee02558f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -59,6 +59,22 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC() is_same_v; } +template +constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC() +{ + return is_NWGK_GKXC_NWGC() || + is_NHWGK_GKYXC_NHWGC() || + is_NDHWGK_GKZYXC_NDHWGC(); +} + +template +constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC() +{ + return is_GNWK_GKXC_GNWC() || + is_GNHWK_GKYXC_GNHWC() || + is_GNDHWK_GKZYXC_GNDHWC(); +} + template struct ComputePtrOffsetOfStridedBatch { diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index 158890d7a..bc290d564 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -27,7 +27,7 @@ template struct TransformConvBwdWeightToGemmV2 { @@ -45,7 +45,7 @@ struct TransformConvBwdWeightToGemmV2 const index_t BatchStride = output_strides[0]; const index_t WoStride = output_strides[4]; const auto KStride = Number<1>{}; - return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumBatchToMerge, K), + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumGroupsToMerge, K), make_tuple(WoStride, BatchStride, KStride)); } @@ -65,13 +65,13 @@ struct TransformConvBwdWeightToGemmV2 if constexpr(ConvBackwardWeightSpecialization == device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { - return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumBatchToMerge, C), + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, 
NumGroupsToMerge, C), make_tuple(WiStride, BatchStride, CStride)); } else { return make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, NumBatchToMerge, C), + make_tuple(N, Hi, Wi, NumGroupsToMerge, C), make_tuple(NStride, HiStride, WiStride, BatchStride, CStride)); } } @@ -88,30 +88,30 @@ struct TransformConvBwdWeightToGemmV2 const auto KStride = weights_strides[1]; const auto XStride = weights_strides[4]; const auto BatchStride = weights_strides[0]; - // Add NumBatchToMerge for Batch+M dimension and, 1 as a placehorder + // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placehorder // for Batch+N dimension const auto desc = make_naive_tensor_descriptor( - make_tuple(NumBatchToMerge, K, Y * X, 1, C), + make_tuple(NumGroupsToMerge, K, Y * X, 1, C), make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); - // Padd 1 to NumBatchToMerge + // Padd 1 to NumGroupsToMerge const auto padded_desc = transform_tensor_descriptor( desc, - make_tuple(make_pass_through_transform(NumBatchToMerge), + make_tuple(make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(K), make_pass_through_transform(Y * X), - make_pad_transform(1, 0, NumBatchToMerge - 1), + make_pad_transform(1, 0, NumGroupsToMerge - 1), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); // We need only matrices from diagonal. Xor returns 0 for the same // values. So if matrices is not on diagonal then it will be stored in padding. // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. - static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 || - NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 || - NumBatchToMerge == 64); + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)), + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), make_pass_through_transform(K), make_pass_through_transform(Y * X), make_pass_through_transform(C)), @@ -120,8 +120,8 @@ struct TransformConvBwdWeightToGemmV2 // Merge To M, N return transform_tensor_descriptor( unmerged_padded_desc, - make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)), - make_merge_transform(make_tuple(Y * X, NumBatchToMerge, C))), + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(Y * X, NumGroupsToMerge, C))), make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -138,7 +138,7 @@ struct TransformConvBwdWeightToGemmV2 const index_t BatchStride = output_strides[0]; const index_t WoStride = output_strides[5]; const auto KStride = Number<1>{}; - return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumBatchToMerge, K), + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumGroupsToMerge, K), make_tuple(WoStride, BatchStride, KStride)); } @@ -160,13 +160,13 @@ struct TransformConvBwdWeightToGemmV2 if constexpr(ConvBackwardWeightSpecialization == device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { - return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, 
NumBatchToMerge, C), + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumGroupsToMerge, C), make_tuple(WiStride, BatchStride, CStride)); } else { return make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, NumBatchToMerge, C), + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride)); } } @@ -184,29 +184,29 @@ struct TransformConvBwdWeightToGemmV2 const auto KStride = weights_strides[1]; const auto XStride = weights_strides[5]; const auto BatchStride = weights_strides[0]; - // Add NumBatchToMerge for Batch+M dimension and, 1 for placehord for Batch+N dimension + // Add NumGroupsToMerge for Batch+M dimension and, 1 for placehord for Batch+N dimension const auto desc = make_naive_tensor_descriptor( - make_tuple(NumBatchToMerge, K, Z * Y * X, 1, C), + make_tuple(NumGroupsToMerge, K, Z * Y * X, 1, C), make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); - // Padd 1 to NumBatchToMerge + // Padd 1 to NumGroupsToMerge const auto padded_desc = transform_tensor_descriptor( desc, - make_tuple(make_pass_through_transform(NumBatchToMerge), + make_tuple(make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(K), make_pass_through_transform(Z * Y * X), - make_pad_transform(1, 0, NumBatchToMerge - 1), + make_pad_transform(1, 0, NumGroupsToMerge - 1), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); // We need only matrices from diagonal. Xor returns 0 for the same // values. So if matrices is not on diagonal then it will be stored in padding. // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. 
- static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 || - NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 || - NumBatchToMerge == 64); + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); const auto unmerged_padded_desc = transform_tensor_descriptor( padded_desc, - make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)), + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), make_pass_through_transform(K), make_pass_through_transform(Z * Y * X), make_pass_through_transform(C)), @@ -215,8 +215,8 @@ struct TransformConvBwdWeightToGemmV2 // Merge To M, N return transform_tensor_descriptor( unmerged_padded_desc, - make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)), - make_merge_transform(make_tuple(Z * Y * X, NumBatchToMerge, C))), + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(Z * Y * X, NumGroupsToMerge, C))), make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); } @@ -262,8 +262,8 @@ struct TransformConvBwdWeightToGemmV2 const index_t InRightPadW = input_right_pads[1]; const index_t GemmKTotal = N * Ho * Wo; - const index_t GemmM = K * NumBatchToMerge; - const index_t GemmN = C * X * Y * NumBatchToMerge; + const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * X * Y * NumGroupsToMerge; const auto PadGemmM = MPerBlock - GemmM % MPerBlock; const auto PadGemmN = NPerBlock - GemmN % NPerBlock; @@ -286,7 +286,7 @@ struct TransformConvBwdWeightToGemmV2 out_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -302,7 +302,7 @@ struct TransformConvBwdWeightToGemmV2 in_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -324,7 +324,7 @@ struct TransformConvBwdWeightToGemmV2 out_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -341,7 +341,7 @@ struct TransformConvBwdWeightToGemmV2 make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(C)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), @@ -354,7 +354,7 @@ struct TransformConvBwdWeightToGemmV2 make_pass_through_transform(N), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - 
make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(C)), make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), @@ -366,7 +366,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, NumBatchToMerge, C)), + make_tuple(make_merge_transform(make_tuple(Y, X, NumGroupsToMerge, C)), make_merge_transform(make_tuple(N, Ho, Wo))), make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); @@ -465,8 +465,8 @@ struct TransformConvBwdWeightToGemmV2 const index_t InRightPadW = input_right_pads[2]; const index_t GemmKTotal = N * Do * Ho * Wo; - const index_t GemmM = K * NumBatchToMerge; - const index_t GemmN = C * Z * X * Y * NumBatchToMerge; + const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * Z * X * Y * NumGroupsToMerge; const auto PadGemmM = MPerBlock - GemmM % MPerBlock; const auto PadGemmN = NPerBlock - GemmN % NPerBlock; @@ -489,7 +489,7 @@ struct TransformConvBwdWeightToGemmV2 out_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -505,7 +505,7 @@ struct TransformConvBwdWeightToGemmV2 in_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -527,7 +527,7 @@ struct TransformConvBwdWeightToGemmV2 out_grid_desc, make_tuple( make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -545,7 +545,7 @@ struct TransformConvBwdWeightToGemmV2 make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -567,7 +567,7 @@ struct TransformConvBwdWeightToGemmV2 make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -584,7 +584,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( in_n_z_do_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumBatchToMerge, C)), + make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumGroupsToMerge, C)), make_merge_transform(make_tuple(N, Do, Ho, Wo))), make_tuple(Sequence<1, 3, 5, 
7, 8>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 2a02d2534..8dd657301 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -25,11 +25,17 @@ __host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T return init; } -template +template struct TransformConvFwdToGemm { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; static long_index_t calculate_element_space_size_impl(const std::array& lengths, @@ -117,13 +123,18 @@ struct TransformConvFwdToGemm const std::array& input_right_pads, const index_t N) { - const index_t C = a_g_n_c_wis_lengths[2]; + const index_t C = a_g_n_c_wis_lengths[I2]; - const index_t Wi = a_g_n_c_wis_lengths[3]; + const index_t Wi = a_g_n_c_wis_lengths[I3]; - const index_t Wo = c_g_n_k_wos_lengths[3]; + const index_t Wo = c_g_n_k_wos_lengths[I3]; - const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[I0]; + + const index_t GStride = a_g_n_c_wis_strides[I0]; + const index_t NStride = a_g_n_c_wis_strides[I1]; + const auto CStride = a_g_n_c_wis_strides[I2]; + const index_t WiStride = a_g_n_c_wis_strides[I3]; if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) @@ -132,41 +143,135 @@ struct TransformConvFwdToGemm N * ck::accumulate_n( c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NHoWo, C), + make_tuple(WiStride, CStride)); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) + device::ConvolutionForwardSpecialization::Filter3x3) { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t WiStride = a_g_n_c_wis_strides[3]; - const auto CStride = I1; - - const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + const index_t ConvDilationW = conv_filter_dilations[0]; - const auto in_n_wo_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, 
Sequence<1>{}, Sequence<2>{})); + const index_t InLeftPadW = input_left_pads[0]; - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const index_t InRightPadW = input_right_pads[0]; + if constexpr(NumGroupsToMerge == 1) + { - return in_gemmm_gemmk_desc; + const auto in_n_wi_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Wi), make_tuple(NStride, WiStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(Number<3>{})), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, NumGroupsToMerge), make_tuple(NStride, WiStride, GStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), + make_pass_through_transform(Number<3>{})), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Wi, 
NumGroupsToMerge, C), + make_tuple(NStride, WiStride, GStride, CStride)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else { @@ -174,40 +279,67 @@ struct TransformConvFwdToGemm const index_t ConvDilationW = conv_filter_dilations[0]; const index_t InLeftPadW = input_left_pads[0]; const index_t InRightPadW = input_right_pads[0]; - - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t WiStride = a_g_n_c_wis_strides[3]; - const auto CStride = I1; - - const auto in_n_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); - - const auto in_n_wip_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_n_x_wo_c_desc = transform_tensor_descriptor( - in_n_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), - make_merge_transform(make_tuple(X, C))), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Wi, NumGroupsToMerge, C), + make_tuple(NStride, WiStride, GStride, CStride)); + + 
const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } } @@ -242,51 +374,160 @@ struct TransformConvFwdToGemm const index_t ConvStrideH = conv_filter_strides[0]; const index_t ConvStrideW = conv_filter_strides[1]; + const index_t GStride = a_g_n_c_wis_strides[I0]; + const index_t NStride = a_g_n_c_wis_strides[I1]; + const index_t CStride = a_g_n_c_wis_strides[I2]; + const index_t HiStride = a_g_n_c_wis_strides[I3]; + const index_t WiStride = a_g_n_c_wis_strides[I4]; + if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { const index_t NHoWo = N * ck::accumulate_n( c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NHoWo, C), + make_tuple(WiStride, CStride)); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride)); + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; - return in_gemmm_gemmk_desc; + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi), make_tuple(NStride, HiStride, WiStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, 
Sequence<1>{}, Sequence<2>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, NumGroupsToMerge), + make_tuple(NStride, HiStride, WiStride, GStride)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Pad0) { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t HiStride = a_g_n_c_wis_strides[3]; - const index_t WiStride = a_g_n_c_wis_strides[4]; - const auto CStride = I1; - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); - - const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = 
make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, HiStride, WiStride, GStride, CStride)); + + const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else { @@ -302,42 +543,81 @@ struct TransformConvFwdToGemm const index_t InRightPadH = input_right_pads[0]; const index_t InRightPadW = input_right_pads[1]; - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t HiStride = a_g_n_c_wis_strides[3]; - const index_t WiStride = a_g_n_c_wis_strides[4]; - const auto CStride = I1; - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_merge_transform(make_tuple(Y, X, C))), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + if 
constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { - return in_gemmm_gemmk_desc; + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, HiStride, WiStride, GStride, CStride)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5>{}, + Sequence<6>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } } @@ -375,6 +655,13 @@ struct TransformConvFwdToGemm const index_t ConvStrideH = conv_filter_strides[1]; const index_t ConvStrideW = conv_filter_strides[2]; + const index_t GStride = a_g_n_c_wis_strides[I0]; + const index_t NStride = a_g_n_c_wis_strides[I1]; + const index_t CStride = a_g_n_c_wis_strides[I2]; + const index_t DiStride = a_g_n_c_wis_strides[I3]; + const index_t HiStride = a_g_n_c_wis_strides[I4]; + const index_t WiStride = a_g_n_c_wis_strides[I5]; + if constexpr(ConvForwardSpecialization == 
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -382,49 +669,182 @@ struct TransformConvFwdToGemm N * ck::accumulate_n( c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - // This is different - const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial]; - const auto CStride = I1; - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride)); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), + make_tuple(WiStride, CStride)); + } + else + { + const auto in_gemmm_groups_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NDoHoWo, NumGroupsToMerge, C), + make_tuple(WiStride, GStride, CStride)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) + device::ConvolutionForwardSpecialization::Filter3x3) { - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t DiStride = a_g_n_c_wis_strides[3]; - const index_t HiStride = a_g_n_c_wis_strides[4]; - const index_t WiStride = a_g_n_c_wis_strides[5]; - const auto CStride = I1; - - const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; - const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_do_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi), make_tuple(NStride, DiStride, HiStride, WiStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge), + make_tuple(NStride, DiStride, HiStride, WiStride, GStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, 
Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } else { @@ -444,53 +864,107 @@ struct TransformConvFwdToGemm const index_t InRightPadH = input_right_pads[1]; const index_t InRightPadW = input_right_pads[2]; - // This is different - const index_t NStride = a_g_n_c_wis_strides[1]; - const index_t DiStride = a_g_n_c_wis_strides[3]; - const index_t HiStride = a_g_n_c_wis_strides[4]; - const index_t WiStride = a_g_n_c_wis_strides[5]; - const auto CStride = I1; - - const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_merge_transform(make_tuple(Z, Y, X, C))), - make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc 
= make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{}, + Sequence<8>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } } @@ -499,9 +973,8 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto - MakeBDescriptor_N_K(const std::array& 
b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */) + static auto MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) { const index_t K = b_g_k_c_xs_lengths[1]; const index_t C = b_g_k_c_xs_lengths[2]; @@ -509,10 +982,54 @@ struct TransformConvFwdToGemm const index_t YX = ck::accumulate_n( b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - const auto wei_gemmn_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); + const index_t GStride = b_g_k_c_xs_strides[I0]; + const index_t KStride = b_g_k_c_xs_strides[I1]; + const index_t CStride = b_g_k_c_xs_strides[I2]; - return wei_gemmn_gemmk_desc; + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + using FilterSizeNumType = + std::conditional_t, + std::conditional_t, Number<27>>>; + + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K, FilterSizeNumType{})); + } + else + { + + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K, NumGroupsToMerge, FilterSizeNumType{}), + make_tuple(KStride, GStride, CStride)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)), + make_pass_through_transform(FilterSizeNumType{})), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); + } + else + { + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K, NumGroupsToMerge, YX * C), make_tuple(KStride, GStride, CStride)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)), + make_pass_through_transform(YX * C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } } template < @@ -585,17 +1102,53 @@ struct TransformConvFwdToGemm { const index_t K = c_g_n_k_wos_lengths[2]; - const auto KStride = I1; + const index_t KStride = I1; const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2]; + const index_t GStride = c_g_n_k_wos_strides[0]; const index_t NHoWo = N * ck::accumulate_n( c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto out_gemmm_gemmn_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride)); - - return out_gemmm_gemmn_desc; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NHoWo, K), + make_tuple(WoStride, KStride)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, NumGroupsToMerge, K, 1), + make_tuple(WoStride, GStride, KStride, GStride)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_pass_through_transform(NHoWo), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. 
+ // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K, NumGroupsToMerge))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } } // for output bias diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index 77d372843..41303d2e9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -40,10 +40,10 @@ template using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple< // clang-format off - //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumBatch| - //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| - //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp new file mode 100644 index 000000000..96baf6bb0 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
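+// Instance lists for "merged groups" forward convolution: per the column comments below, the
+// trailing template argument (8, 16 or 32) is the number of groups folded into a single GEMM
+// (NumGroupsToMerge); the power-of-two values appear to match the xor-based descriptor trick
+// used in TransformConvFwdToGemm.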
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd3x3 = ConvolutionForwardSpecialization::Filter3x3; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| 
Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | 
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index ec5bd785a..0233d6d85 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -17,6 +17,7 @@ #endif #ifdef CK_USE_XDL #include "grouped_convolution_forward_xdl.inc" +#include "grouped_convolution_forward_xdl_merged_groups.inc" #include "grouped_convolution_forward_comp_xdl.inc" #include "grouped_convolution_forward_mem_inter_xdl.inc" #include "grouped_convolution_forward_mem_intra_xdl.inc" @@ -199,6 +200,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( op_ptrs); @@ -212,6 +215,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances( op_ptrs); @@ -227,6 +232,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( op_ptrs); @@ -284,6 +291,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( op_ptrs); @@ -338,6 +347,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs); 
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances( op_ptrs); @@ -353,6 +364,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc new file mode 100644 index 000000000..fe09d3f6a --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 4e002c722..170625a6a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -9,6 +9,11 @@ add_instance_library(device_grouped_conv2d_fwd_instance xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + # merged groups + # NHWGC, GKYXC, NHWGK + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp #mem # NHWGC, GKYXC, NHWGK xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 000000000..6fa4bc6e4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 000000000..9fa56f48c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp new file mode 100644 index 000000000..e226dae97 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index e24dbcd2c..5be667272 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -9,6 +9,10 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp 
xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 000000000..cf1fcec98 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 000000000..bea62892d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 000000000..de4472541 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index 21fe7992a..1bfc18313 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -104,6 +104,7 @@ TYPED_TEST(TestGroupedConvndFwd1d, Test1D) this->conv_params.push_back({1, 2, 32, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); this->conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}}); this->conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 96, 1, 1, 1, {3}, {512}, {1}, {1}, {1}, {1}}); this->template Run<1>(); } @@ -119,6 +120,8 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D) this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, 
{1, 1}, {1, 1}}); this->template Run<2>(); } @@ -137,6 +140,8 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D) {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {4, 30, 160}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->template Run<3>(); } @@ -144,6 +149,9 @@ TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases) { // Case larger than 2GB this->conv_params.push_back( - {2, 1, 64, 4, 192, {2, 2}, {224, 224}, {224, 224}, {0, 0}, {0, 0}, {0, 0}}); + {2, 1, 64, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}}); + // With supported NumGroupsToMerge > 1 + this->conv_params.push_back( + {2, 32, 64, 1, 1, {2, 2}, {672, 672}, {672, 672}, {1, 1}, {0, 0}, {0, 0}}); this->template Run<2>(); } -- GitLab From eca39050c6f5b8b90282cd9e0a5b26440ecb6577 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 16 Jul 2024 08:44:46 -0700 Subject: [PATCH 90/96] add Rosty and Bartek to code owners (#1392) --- .github/CODEOWNERS | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1809abebb..459315e58 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @junliume @illsilin @carlushuang @aosewski @poyenc +* @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk # Documentation files -docs/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc -*.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc -*.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc -.readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc +docs/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +*.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +*.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +.readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc +library/include/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk -- GitLab From 1ff4f25138f0aa498de24fc0a10c88585844a272 Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Tue, 16 Jul 2024 23:46:48 +0800 Subject: [PATCH 91/96] Disable failed instance in rocm6.2 rel (#1388) --- ...gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp | 3 +- .../profiler/profile_gemm_universal_impl.hpp | 38 +++++++++---------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp index 452a9c963..f2eb52b49 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp @@ -43,7 +43,8 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // Disable due to test failure + // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp index 362a5dccd..7fcadd7f7 100644 --- a/profiler/include/profiler/profile_gemm_universal_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -191,7 +191,24 @@ bool profile_gemm_universal_impl(int do_verification, { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#if defined CK_ENABLE_FP8 + // set softer tolerances for fp8 + if constexpr(is_same_v || is_same_v || + is_same_v) + { + std::string msg = "Error: Incorrect results!"; + double rtol = 1e-1; + double atol = 1e-1; + pass = pass & ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); + } + else + { 
+#endif + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#if defined CK_ENABLE_FP8 + } +#endif if(do_log) { @@ -230,25 +247,6 @@ bool profile_gemm_universal_impl(int do_verification, << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl; -#if defined CK_ENABLE_FP8 - // set softer tolerances for fp8 - if constexpr(is_same_v || is_same_v || - is_same_v) - { - std::string msg = "Error: Incorrect results!"; - double rtol = 1e-1; - double atol = 1e-1; - pass = pass & ck::utils::check_err( - c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); - } - else - { -#endif - pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); -#if defined CK_ENABLE_FP8 - } -#endif - if(tflops > best_tflops) { best_op_name = op_name; -- GitLab From 802a8a1df1f0b67d95702ef28b777e489d8b4de7 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Tue, 16 Jul 2024 11:51:49 -0400 Subject: [PATCH 92/96] Adding more instances of grouped convolution 3d forward for FP8 with ConvScale element-wise operation and ReLU activation. (#1386) * Add CMakePresets configurations. * Add ConvScale+ReLU Functor and an Example * Account for ReLU FLOPs. * Add instances of 3D convolutions with ConvscaleRelu operation. * Implement Client Example * Cleanup --- .../24_grouped_conv_activation/CMakeLists.txt | 4 + .../common.hpp | 316 ++++++++++++++++++ .../conv3d_fwd_convscale_relu_fp8.cpp | 50 +++ example/62_convnd_activ/CMakeLists.txt | 1 + .../convscale_relu/CMakeLists.txt | 11 + .../convnd_fwd_convscale_relu_common.hpp | 302 +++++++++++++++++ .../convnd_fwd_xdl_convscale_relu_fp8.cpp | 86 +++++ .../run_convnd_fwd_convscale_relu_example.inc | 104 ++++++ .../element/unary_element_wise_operation.hpp | 25 ++ ...ped_convolution_forward_convscale_relu.hpp | 105 ++++++ .../CMakeLists.txt | 5 + ..._relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp | 63 ++++ 12 files changed, 1072 insertions(+) create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/common.hpp create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp create mode 100644 example/62_convnd_activ/convscale_relu/CMakeLists.txt create mode 100644 example/62_convnd_activ/convscale_relu/convnd_fwd_convscale_relu_common.hpp create mode 100644 example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp create mode 100644 example/62_convnd_activ/convscale_relu/run_convnd_fwd_convscale_relu_example.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index 77e54f1f1..37bace920 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -39,6 +39,10 @@ target_link_libraries(client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE c add_executable(client_conv3d_fwd_convinvscale_fp8 grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp) 
target_link_libraries(client_conv3d_fwd_convinvscale_fp8 PRIVATE composable_kernel::device_conv_operations) +# Fwd convscale + ReLU +add_executable(client_conv3d_fwd_convscale_relu_fp8 + grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_relu_fp8 PRIVATE composable_kernel::device_conv_operations) # Fwd convscale add_executable(client_conv3d_fwd_convscale_fp8 grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp) diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/common.hpp new file mode 100644 index 000000000..ee188429b --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/common.hpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd_convscale_relu( + std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem 
in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t ds_size = 3 + 1; // 3 element-wise scale multipliers + 1 elementwise Relu + std::size_t flop = GetFlops(out_lengths, wei_lengths, ds_size); + std::size_t num_bytes = + in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + ConvScaleRelu, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + 
std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScaleRelu{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best instance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScaleRelu{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp new file mode 100644 index 000000000..4003dc7c8 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale_relu( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt index fa5606773..96d868de0 100644 --- a/example/62_convnd_activ/CMakeLists.txt +++ b/example/62_convnd_activ/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(binary) add_subdirectory(convinvscale) add_subdirectory(convscale) +add_subdirectory(convscale_relu) add_subdirectory(multi_AB) add_subdirectory(unary) diff --git a/example/62_convnd_activ/convscale_relu/CMakeLists.txt b/example/62_convnd_activ/convscale_relu/CMakeLists.txt new file mode 100644 index 000000000..95589cedc --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/CMakeLists.txt @@ -0,0 +1,11 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl_convscale_relu) + add_example_executable(example_convnd_fwd_xdl_convscale_relu_fp8 convnd_fwd_xdl_convscale_relu_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale_relu example_convnd_fwd_xdl_convscale_relu_fp8 ) + + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/convscale_relu/convnd_fwd_convscale_relu_common.hpp b/example/62_convnd_activ/convscale_relu/convnd_fwd_convscale_relu_common.hpp new file mode 100644 index 000000000..d2dacc205 --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/convnd_fwd_convscale_relu_common.hpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include + +#include "ck/ck.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << 
out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // random scale values + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::cout << std::endl; + std::cout << "scale_in: " << scale_in << std::endl; + std::cout << "scale_wei: " << scale_wei << std::endl; + std::cout << "scale_out: " << scale_out << std::endl; + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t ds_size = 3 + 1; // 3 element-wise scale multipliers + 1 element-wise relu + std::size_t flop = GetFlops(e_g_n_k_wos_lengths, b_g_k_c_xs_lengths, ds_size); + std::size_t num_btype = conv_param.GetInputByte() + + conv_param.GetWeightByte() + sizeof(float) + + sizeof(float) + sizeof(float) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + out_host.ForEach([&](auto&, auto idx) { out_element_op(out_host(idx), c(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp b/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp new file mode 100644 index 000000000..360349e7e --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_convscale_relu_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScaleRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_relu_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale_relu/run_convnd_fwd_convscale_relu_example.inc b/example/62_convnd_activ/convscale_relu/run_convnd_fwd_convscale_relu_example.inc new file mode 100644 index 000000000..797146060 --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/run_convnd_fwd_convscale_relu_example.inc @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + // instantiate in and wei element ops, will + // instantiate out_element_op below for every iteration + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = + [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto ds_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using DsLayout = decltype(ds_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ck::Tuple<>{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ck::Tuple<>{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ck::Tuple<>{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 75429554a..c9ca88374 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1025,6 +1025,31 @@ struct ConvScale float scale_out_; }; +struct ConvScaleRelu +{ + __host__ __device__ ConvScaleRelu(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + float x; + Relu{}.template operator()(x, c * scale_in_ * scale_wei_); + e = type_convert(x * scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + // support fastconvert of int8 to fp16 template diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp new file mode 100644 index 000000000..ad86d066f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + +#ifdef CK_ENABLE_FP8 +void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScaleRelu, + F8, + F8>>>& instances); +#endif + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances( + op_ptrs); + } +#endif + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt new file mode 100644 index 000000000..c60df5a73 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt @@ -0,0 +1,5 @@ +# ONLY XDL_KERNELS +set(GROUPED_CONV3D_FWD_CONVSCALE_RELU + xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp) + +add_instance_library(device_grouped_conv3d_fwd_convscale_relu_instance ${GROUPED_CONV3D_FWD_CONVSCALE_RELU}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp new file mode 100644 index 000000000..472da0da7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + +void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScaleRelu, + F8, + F8>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvScaleRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvScaleRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvScaleRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -- GitLab From 4c3107fdcbc5898570b08fe4db035489420c1dc5 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 16 Jul 2024 09:19:23 -0700 Subject: [PATCH 93/96] [ASAN builds] Modify the list of default targets for ASAN builds. (#1389) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add a build parameter to build only XNACK targets * use ENABLE_ASAN_PACKAGING flag to set targets for ASAN builds --------- Co-authored-by: Bartłomiej Kocot --- CMakeLists.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index fc0cc4ddb..b3421e67e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,8 +111,16 @@ message("checking which targets are supported") #These targets will be filtered and only supported ones will be used #Setting GPU_TARGETS on command line will override this list if(NOT PROFILER_ONLY) + if(NOT ENABLE_ASAN_PACKAGING) + #build CK for all supported targets rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") + else() + #build CK only for xnack-supported targets + rocm_check_target_ids(DEFAULT_GPU_TARGETS + TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+") + set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) + endif() else() add_definitions(-DPROFILER_ONLY) set(GPU_TARGETS "" CACHE STRING "" FORCE) -- GitLab From 9cac2827930b0a35367edca4fb3081e0882aaf0f Mon Sep 17 00:00:00 2001 From: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> Date: Tue, 16 Jul 2024 18:52:44 +0200 Subject: [PATCH 94/96] An option whether to colorize output during build (#1390) --- CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index b3421e67e..7e21a7976 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -450,6 +450,13 @@ if(BUILD_DEV) endif() message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + add_compile_options(-fcolor-diagnostics) +endif() +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION 
VERSION_GREATER 4.9) + add_compile_options(-fdiagnostics-color=always) +endif() + add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") -- GitLab From ee768148f0701262e17787067b965e4d5a850d89 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Thu, 18 Jul 2024 00:15:05 +0800 Subject: [PATCH 95/96] Replace the using of __expf by __ocml_exp_f32 to work-around the test_softmax_rank4 failure (#1394) --- .../gpu/element/unary_element_wise_operation.hpp | 6 +++--- include/ck/utility/math_v2.hpp | 4 ++-- include/ck_tile/core/numeric/bfloat16.hpp | 5 ++++- include/ck_tile/core/numeric/float8.hpp | 4 ++-- include/ck_tile/core/numeric/half.hpp | 2 +- include/ck_tile/core/numeric/math.hpp | 2 +- 6 files changed, 13 insertions(+), 10 deletions(-) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index c9ca88374..bf4a1c800 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -431,7 +431,7 @@ struct Relu // https://paperswithcode.com/method/gelu // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) // host code use higher accuracy "exp" and "div" -// gpu code use lower accuracy "__expf" and "rcp" function +// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function struct FastGelu { template @@ -451,7 +451,7 @@ struct FastGelu y = x / (1.f + emu); } - // device code, use lower precision "__expf" and "rcp" + // device code, use lower precision "__ocml_exp_f32" and "rcp" template <> __device__ void operator()(float& y, const float& x) const { @@ -459,7 +459,7 @@ struct FastGelu const float c1 = -2.0 * 0.035677f; const float c2 = -2.0 * 0.797885f; const float u = x * (c1 * x * x + c2); - const float emu = __expf(u); + const float emu = __ocml_exp_f32(u); y = x * ck::math::rcp(1.f + emu); } diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 2b921cdc7..d961cdb19 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -839,7 +839,7 @@ inline __device__ T rcp(T x) template inline __device__ T exp(T x) { - return ck::type_convert(__expf(ck::type_convert(x))); + return ck::type_convert(__ocml_exp_f32(ck::type_convert(x))); }; template <> @@ -851,7 +851,7 @@ inline __device__ half_t exp(half_t x) template <> inline __device__ float exp(float x) { - return __expf(x); + return __ocml_exp_f32(x); }; template <> diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 071387163..4fdf8f9da 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -331,7 +331,10 @@ bfloat16_t sqrt(bfloat16_t x) }; CK_TILE_DEVICE -bfloat16_t exp(bfloat16_t x) { return static_cast(__expf(static_cast(x))); }; +bfloat16_t exp(bfloat16_t x) +{ + return static_cast(__ocml_exp_f32(static_cast(x))); +}; CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x) { return static_cast(exp2f(static_cast(x))); }; diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index 56ca44e72..b3b1a1f3f 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -835,7 +835,7 @@ CK_TILE_DEVICE fp8_t sqrt(fp8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; CK_TILE_DEVICE 
-fp8_t exp(fp8_t x) { return static_cast(__expf(static_cast(x))); }; +fp8_t exp(fp8_t x) { return static_cast(__ocml_exp_f32(static_cast(x))); }; CK_TILE_DEVICE fp8_t exp2(fp8_t x) { return static_cast(exp2f(static_cast(x))); }; @@ -860,7 +860,7 @@ CK_TILE_DEVICE bf8_t sqrt(bf8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; CK_TILE_DEVICE -bf8_t exp(bf8_t x) { return static_cast(__expf(static_cast(x))); }; +bf8_t exp(bf8_t x) { return static_cast(__ocml_exp_f32(static_cast(x))); }; CK_TILE_DEVICE bf8_t exp2(bf8_t x) { return static_cast(exp2f(static_cast(x))); }; diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index 752145f71..acb6eb6c3 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -374,7 +374,7 @@ half_t sqrt(half_t x) }; CK_TILE_DEVICE -half_t exp(half_t x) { return static_cast(__expf(static_cast(x))); }; +half_t exp(half_t x) { return static_cast(__ocml_exp_f32(static_cast(x))); }; CK_TILE_DEVICE half_t exp2(half_t x) { return static_cast(exp2f(static_cast(x))); }; diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index d4984363d..9970bb369 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -519,7 +519,7 @@ CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; CK_TILE_DEVICE -float exp(float x) { return __expf(x); }; +float exp(float x) { return __ocml_exp_f32(x); }; CK_TILE_HOST float exp(float x) { return std::expf(x); } -- GitLab From ab250afda0e99e770634b7994b272c9f6d4f0e7a Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 18 Jul 2024 09:41:33 -0700 Subject: [PATCH 96/96] add docker for rocm6.2_rc3 (#1401) --- Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 0c98188b9..196b0ee1c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,11 +23,11 @@ RUN if [ "$ROCMVERSION" != "6.2" ]; then \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ - elif [ "$ROCMVERSION" = "6.2" ] && [ "$compiler_version" = "rc1" ]; then \ + elif [ "$ROCMVERSION" = "6.2" ] && [ "$compiler_version" = "rc3" ]; then \ sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.2-20.04-1_all.deb --no-check-certificate" && \ apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.2-20.04-1_all.deb && \ - sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.2 rel-8 > /etc/apt/sources.list.d/rocm-build.list' && \ - amdgpu-repo --amdgpu-build=1794148; \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.2 rel-45 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=2003709; \ fi RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" -- GitLab
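
The ConvScaleRelu output op introduced in the grouped-conv forward patches above computes, per element, e = convert_to_f8(relu(c * scale_in_ * scale_wei_) * scale_out_). The host-side sketch below reproduces only that arithmetic for reference; ConvScaleReluRef and saturate_to_e4m3_range are illustrative stand-ins rather than CK code (a real build narrows through ck::type_convert to ck::f8_t inside the generated device instances), and the clamp is a simplification of fp8 conversion, not CK's actual rounding.

#include <algorithm>
#include <cstdio>

// Rough stand-in for the fp8 (e4m3) narrowing: only saturation to the
// representable range is modeled here, not rounding or NaN handling.
static float saturate_to_e4m3_range(float x)
{
    return std::max(-448.f, std::min(448.f, x));
}

// Host-side reference of the fused output op: scale input/weight, apply Relu,
// scale the result, then narrow toward fp8.
struct ConvScaleReluRef
{
    float scale_in  = 1.f;
    float scale_wei = 1.f;
    float scale_out = 1.f;

    float operator()(float c) const
    {
        const float x = std::max(0.f, c * scale_in * scale_wei); // Relu after input/weight scaling
        return saturate_to_e4m3_range(x * scale_out);            // output scaling, then fp8 clamp
    }
};

int main()
{
    ConvScaleReluRef op{0.5f, 2.f, 0.25f};
    std::printf("%f %f\n", op(-3.f), op(8.f)); // prints 0.000000 2.000000
    return 0;
}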
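
Patch 95/96 replaces __expf with __ocml_exp_f32 in the device-side exp paths. The standalone HIP sketch below mirrors that substitution outside of CK; fast_exp, exp_kernel, and the launch configuration are hypothetical names chosen for illustration, and the only toolchain assumption is that hipcc declares __ocml_exp_f32 via its libdevice headers, as ROCm does.

#include <hip/hip_runtime.h>
#include <cmath>
#include <cstdio>

// exp wrapper mirroring the patched ck::math::exp: OCML exp on device,
// standard exp on host.
__host__ __device__ inline float fast_exp(float x)
{
#if defined(__HIP_DEVICE_COMPILE__)
    return __ocml_exp_f32(x); // device path used after this patch series
#else
    return std::exp(x); // host path
#endif
}

__global__ void exp_kernel(const float* in, float* out, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        out[i] = fast_exp(in[i]);
}

int main()
{
    const int n = 4;
    float h_in[n]  = {0.f, 1.f, -1.f, 2.f};
    float h_out[n] = {};
    float *d_in = nullptr, *d_out = nullptr;
    hipMalloc(&d_in, n * sizeof(float));
    hipMalloc(&d_out, n * sizeof(float));
    hipMemcpy(d_in, h_in, n * sizeof(float), hipMemcpyHostToDevice);
    exp_kernel<<<1, 64>>>(d_in, d_out, n);
    hipMemcpy(h_out, d_out, n * sizeof(float), hipMemcpyDeviceToHost);
    for(int i = 0; i < n; ++i)
        std::printf("exp(%f) = %f\n", h_in[i], h_out[i]);
    hipFree(d_in);
    hipFree(d_out);
    return 0;
}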