diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml index 4161c2d5a4e54e731a356656bbff8864326c7fee..b37b8cc27fcc2b5f7419dc36b854fe8962ba3734 100644 --- a/.azuredevops/rocm-ci.yml +++ b/.azuredevops/rocm-ci.yml @@ -14,6 +14,7 @@ trigger: branches: include: - develop + - amd-develop paths: exclude: - .github diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d7a6b17783894b28bf0aca9fe45a32e6b0c2d480..f6ab388e2a509281e6c595b0c58f28cfb8da979c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca +* @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj # Documentation files -docs/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca -*.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca -*.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca -.readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca +docs/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj +*.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj +*.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj +.readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca +library/include/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index b3fcabec34a842d218c19f8411361a17f50a4600..8a988ad1c9e4dc50b57feb2bd3eed542e2801a30 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -7,6 +7,7 @@ Please describe the motivation behind the pull request, whether it enables a new Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally +- [ ] I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request diff --git a/CMakeLists.txt b/CMakeLists.txt index d5d4cc64a9113fa861eb43ca5abcec1dfc9ec1fe..37962c14e3032c5344c1c1c0d520cdb92bbfcdd7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,14 +92,26 @@ endif() add_compile_options(-Wno-bit-int-extension) add_compile_options(-Wno-pass-failed) add_compile_options(-Wno-switch-default) +add_compile_options(-Wno-unique-object-duplication) if(DL_KERNELS) add_definitions(-DDL_KERNELS) set(CK_ENABLE_DL_KERNELS "ON") endif() +if(DPP_KERNELS) + add_definitions(-DDPP_KERNELS) + set(CK_ENABLE_DPP_KERNELS "ON") +endif() option(CK_USE_CODEGEN "Enable codegen library" OFF) if(CK_USE_CODEGEN) - add_definitions(-DCK_USE_CODEGEN) + add_definitions(-DCK_USE_CODEGEN) +endif() + +option(CK_TIME_KERNEL "Enable kernel time tracking" ON) +if(CK_TIME_KERNEL) + add_definitions(-DCK_TIME_KERNEL=1) +else() + add_definitions(-DCK_TIME_KERNEL=0) endif() include(getopt) @@ -185,17 +197,20 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") message("Enabling FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx95") + add_definitions(-DCK_USE_AMD_MFMA_GFX950) +endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") message("Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_OCP_FP8) set(CK_USE_OCP_FP8 "ON") endif() @@ -203,6 +218,10 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx add_definitions(-DCK_USE_FNUZ_FP8) set(CK_USE_FNUZ_FP8 "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") + add_definitions(-DCK_USE_NATIVE_MX_SUPPORT) + set(CK_USE_NATIVE_MX_SUPPORT "ON") +endif() option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) @@ -525,7 +544,13 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS add_compile_options(-fdiagnostics-color=always) endif() +# make check runs the entire set of examples and tests add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) +# make smoke runs the tests and examples that run within 30 seconds on gfx90a +add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST") +# make regression runs the tests and examples that run for more than 30 seconds on gfx90a +add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST") + file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*) @@ -581,7 +606,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS) ) add_subdirectory(example) if(BUILD_TESTING) - add_subdirectory(test) + add_subdirectory(test) endif() endif() diff --git a/Dockerfile b/Dockerfile index a3bf3866bf225057028d6b28c460740b129d75b9..2873a8500b6d95e36a3fe2856d28eb20e3bf8fa3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -94,7 +94,7 @@ RUN pip install --upgrade cmake==3.27.5 && \ dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \ # Install packages for processing the performance results pip3 install --upgrade pip && \ - pip3 install sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust sshtunnel==0.4.0 && \ + pip3 install --upgrade pytest sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust setuptools>=75 sshtunnel==0.4.0 && \ # Add render group groupadd -f render && \ # Install the new rocm-cmake version diff --git a/Jenkinsfile b/Jenkinsfile index 87c9457fcb4f49ddeb464532677c9a2e11e9ec69..80392bfbedd02525ff432c45328595f326fa6c54 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -117,7 +117,7 @@ def getDockerImage(Map conf=[:]){ { echo "Pulling down image: ${image}" retimage = docker.image("${image}") - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.pull() } } @@ -148,7 +148,7 @@ def buildDocker(install_prefix){ //force building the new docker if that parameter is true echo "Building image: ${image_name}" retimage = docker.build("${image_name}", dockerArgs) - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.push() } sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi' @@ -162,7 +162,7 @@ def buildDocker(install_prefix){ catch(Exception ex){ echo "Unable to locate image: ${image_name}. Building image now" retimage = docker.build("${image_name}", dockerArgs + ' .') - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.push() } } @@ -326,12 +326,38 @@ def cmake_build(Map conf=[:]){ if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) { archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } + //check the node gpu architecture + def arch_type = 0 + sh 'rocminfo | tee rocminfo.log' + if ( runShell('grep -n "gfx90a" rocminfo.log') ){ + arch_type = 1 + } + else if ( runShell('grep -n "gfx942" rocminfo.log') ) { + arch_type = 2 + } if (params.RUN_CK_TILE_FMHA_TESTS){ try{ - archiveArtifacts "perf_fmha_fwd_*.log" - archiveArtifacts "perf_fmha_bwd_*.log" - stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942" - stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a" + archiveArtifacts "perf_fmha_*.log" + if (arch_type == 1){ + stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a" + } + else if (arch_type == 2){ + stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942" + } + } + catch(Exception err){ + echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." 
+ } + } + if (params.RUN_CK_TILE_GEMM_TESTS){ + try{ + archiveArtifacts "perf_tile_gemm_*.log" + if (arch_type == 1){ + stash includes: "perf_tile_gemm_**_fp16_gfx90a.log", name: "perf_tile_gemm_log_gfx90a" + } + else if (arch_type == 2){ + stash includes: "perf_tile_gemm_**_fp16_gfx942.log", name: "perf_tile_gemm_log_gfx942" + } } catch(Exception err){ echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." @@ -486,6 +512,13 @@ def Build_CK(Map conf=[:]){ arch_type = 5 } cmake_build(conf) + if ( !params.BUILD_LEGACY_OS && arch_type == 1 ){ + echo "Run inductor codegen tests" + sh """ + pip install --verbose . + pytest python/test/test_gen_instances.py + """ + } dir("build"){ if (params.RUN_FULL_QA && arch_type == 1 ){ // build deb packages for all gfx9 targets on gfx90a system and prepare to export @@ -630,6 +663,15 @@ def process_results(Map conf=[:]){ echo "could not locate the FMHA performance logs: ${err.getMessage()}." } } + if (params.RUN_CK_TILE_GEMM_TESTS){ + try{ + unstash "perf_tile_gemm_log_gfx942" + unstash "perf_tile_gemm_log_gfx90a" + } + catch(Exception err){ + echo "could not locate the GEMM performance logs: ${err.getMessage()}." + } + } if (params.RUN_FULL_QA){ // unstash perf files to master unstash "ckprofiler_0.2.0_amd64.deb" @@ -753,8 +795,8 @@ pipeline { description: "Run the ck_tile FMHA tests (default: OFF)") booleanParam( name: "RUN_CK_TILE_GEMM_TESTS", - defaultValue: false, - description: "Run the ck_tile GEMM tests (default: OFF)") + defaultValue: true, + description: "Run the ck_tile GEMM tests (default: ON)") booleanParam( name: "BUILD_INSTANCES_ONLY", defaultValue: false, @@ -956,7 +998,7 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 tile_example_gemm_basic && \ + make -j64 tile_example_gemm_basic tile_example_gemm_universal && \ cd ../ && example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ } @@ -975,7 +1017,7 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ - make -j64 tile_example_gemm_basic && \ + make -j64 tile_example_gemm_basic tile_example_gemm_universal && \ cd ../ && example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ } diff --git a/LICENSE b/LICENSE index 581b5efde535f686aa0a584709fccaa92353d125..68f6ae5746ddd5c79829fd7cc7e32584ffd9d822 100644 --- a/LICENSE +++ b/LICENSE @@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) SPDX-License-Identifier: MIT -Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index c0872aa5678788a3bde2e1ff46c8b2b4fc9e30d8..95f44d887263260b89c8a73048cf4d78c3405974 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,15 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa You can find instructions for running each individual example in [example](/example). 
+* Build and run smoke/regression examples and tests: + + ```bash + make -j smoke # tests and examples that run for < 30 seconds each + ``` + ```bash + make -j regression # tests and examples that run for >= 30 seconds each + ``` + * Build ckProfiler: ```bash @@ -153,6 +162,9 @@ Additional cmake flags can be used to significantly speed-up the build: `batched_gemm_multi_d_dl`. These instances are useful on architectures like the NAVI2x, as most other platforms have faster instances, such as `xdl` or `wmma`, available. +* `DPP_KERNELS` (default is OFF) must be set to ON in order to build instances, such as `gemm_dpp`. + These instances are useful on architectures like the NAVI2x, as most other platforms have faster instances, such as `xdl` or `wmma`, available. + * `CK_USE_FP8_ON_UNSUPPORTED_ARCH` (default is OFF) must be set to ON in order to build instances, such as `gemm_universal`, `gemm_universal_streamk` and `gemm_multiply_multiply` for fp8 data type for GPU targets which do not have native support for fp8 data type, such as gfx908 or gfx90a. These instances are useful on architectures like the MI100/MI200 for the functional support only. diff --git a/client_example/01_gemm/README.md b/client_example/01_gemm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6dcd1e29598a4be40f117c65b1101bda62639ccf --- /dev/null +++ b/client_example/01_gemm/README.md @@ -0,0 +1,126 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel GEMM + +## GEMM +General matrix multiplication operation. In CK, the GEMM operation is called `DeviceGemm` and requires the following types as template parameters: + +* **ALayout** - A matrix layout (RowMajor/ColumnMajor). +* **BLayout** - B matrix layout (RowMajor/ColumnMajor). +* **CLayout** - C matrix layout (RowMajor/ColumnMajor). +* **ADataType** - A matrix data type. +* **BDataType** - B matrix data type. +* **CDataType** - C matrix data type. +* **AElementwiseOperation** - Fused operation on tensor A before GEMM. +* **BElementwiseOperation** - Fused operation on tensor B before GEMM. +* **CElementwiseOperation** - Fused operation on tensor C after GEMM. + +For matrices with a large K dimension, the `DeviceGemmSplitK` implementation is available. This implementation allows the user to split the K dimension between work groups. It uses an `AtomicAdd` operation on global memory, so the output buffer must be zeroed out for correct results. + +For fused operations with additional tensors, the `DeviceGemmMultipleABD` and `DeviceGemmMultipleD` operations are available, which require the following additional parameters: +* **DsLayout** - layouts for additional tensors for fused operations. +* **DsDataType** - data types for additional tensors for fused operations. + +For `DeviceGemmMultipleABD`, the user should pass tuples for **ALayout**, **BLayout**, **ADataType** and **BDataType**. + +List of the device operations in CK: + +* **DeviceGemmDl** - Device operation with DL instructions. +* **DeviceGemmDpp** - Device operation with DL instructions that uses DPP instructions during data load. +* **DeviceGemmWmma_CShuffle** - Device operation with WMMA instructions and CShuffle optimization for more optimized data store. +* **DeviceGemm_Xdl_CShuffle_LdsDirectLoad** - Device operation with XDL instructions and CShuffle optimization for more optimized data store and direct load from global memory to shared memory. +* **DeviceGemm_Xdl_CShuffle** - Device operation with XDL instructions and CShuffle optimization for more optimized data store.
+* **DeviceGemm_Xdl_CShuffleV2** - Device operation with XDL instructions and CShuffle optimization for more optimized data store. The GEMM pipeline has been optimized compared to **DeviceGemm_Xdl_CShuffle**. +* **DeviceGemmXdlSkipBLds** - Device operation with XDL instructions. The load to shared memory is skipped for the B matrix. +* **DeviceGemm_Xdl_WaveletModel_CShuffle** - Device operation with XDL instructions and CShuffle optimization for more optimized data store. Uses a producer-consumer cooperation scheme between waves in a workgroup. +* **DeviceGemmXdl** - Device operation with XDL instructions. + +Table of supported cases by instance factory with XDL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✓| +|int8|✓| +|fp8 |✓| + +Table of supported cases by instance factory with WMMA instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✗| +|int8|✓| +|fp8 |✗| + +Table of supported cases by instance factory with DL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✗| +|fp16|✓| +|fp32|✓| +|int8|✓| +|fp8 |✗| + +Table of supported cases by instance factory with fused output elementwise operation: + +* **B Matrix Multiply + Add + Gelu** - bf16 (int8 for B matrix) +* **B Matrix Multiply + Add** - bf16 (int8 for B matrix) +* **B Matrix Multiply + Gelu** - bf16 (int8 for B matrix) +* **B Matrix Multiply** - bf16 (int8 for B matrix) + +* **Add + Add + Gelu** - fp16 +* **Add + Gelu** - fp16, bf16 (int8 for B matrix) for Row/Column/Row +* **Multiply** - fp16 +* **Add + Multiply** - fp16 +* **Add + Relu** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Add + Silu** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Add** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Bilinear** - fp16, int8 +* **Gelu** - fp16 +* **Multiply + Add** - fp16 for Row/Column/Row and Row/Row/Row, fp16 (int8 for B matrix, fp32 for Bias) for Row/Column/Row and Row/Row/Row +* **Quantization** - int8 + +## GEMM V2 (Universal GEMM) +General matrix multiplication operation optimized for the MI300 series. The operation is called `DeviceGemmV2` and requires the following types as template parameters: + +* **ALayout** - A matrix layout (RowMajor/ColumnMajor). +* **BLayout** - B matrix layout (RowMajor/ColumnMajor). +* **CLayout** - C matrix layout (RowMajor/ColumnMajor). +* **ADataType** - A matrix data type. +* **BDataType** - B matrix data type. +* **CDataType** - C matrix data type. +* **AElementwiseOperation** - Fused operation on tensor A before GEMM. +* **BElementwiseOperation** - Fused operation on tensor B before GEMM. +* **CElementwiseOperation** - Fused operation on tensor C after GEMM. + +This implementation allows the user to split the K dimension between work groups. It requires an AtomicAdd operation on global memory (the output buffer must be zeroed out if the splitK parameter is larger than one). + +List of the device operations in CK: + +* **DeviceGemm_Xdl_CShuffleV3** - Device operation with XDL instructions and CShuffle optimization for more optimized data store. +* **DeviceGemm_Xdl_CShuffleV3R1** - Device operation with XDL instructions and CShuffle optimization for more optimized data store.
This implementation performs a reduction on the split K dimension after the GEMM instead of using an AtomicAdd instruction. + +Table of supported cases by instance factory with XDL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✗| +|int8|✗| +|fp8 (C bf16)|✓| +|fp16 (A fp8)|✓| +|fp16 (B fp8)|✓| + +## Others + +* **DeviceGemm_dequantB** - GEMM with dequantization (implemented with WMMA instructions). +* **DeviceGemmMultipleD_ABScale** - GEMM with scale for A and B matrix. +* **DeviceGemmMultipleDLayernorm** - GEMM fused with layernorm. +* **DeviceGemmMultipleDMultipleR** - GEMM fused with reductions and custom global reduction operators. +* **DeviceGemmReduce** - GEMM fused with reduction. +* **DeviceGemm_Streamk_V2** - GEMM stream K implementation. The implementation allows the use of a reduction instead of AtomicAdd. +* **DeviceGemmStreamK** - GEMM stream K implementation using AtomicAdd. diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index c953e21d0266f61b1b9fc99de9252a4f00bd57cd..2ea31bdf068149c310b5e90647217e9000884dea 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -22,4 +22,7 @@ if(GPU_TARGETS MATCHES "gfx9") add_executable(client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp) target_link_libraries(client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) endif() + + add_executable(grouped_conv2d_fwd_ngchw grouped_conv2d_fwd_ngchw.cpp) + target_link_libraries(grouped_conv2d_fwd_ngchw PRIVATE composable_kernel::device_conv_operations) endif() diff --git a/client_example/07_grouped_convnd_fwd/README.md b/client_example/07_grouped_convnd_fwd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..28a64ad7337f6ac1bdb218c4b48a74085b2a369d --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/README.md @@ -0,0 +1,68 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Forward +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. The convolution utilizes a GEMM kernel after a tensor coordinate transform. In CK, the Grouped Convolution Forward operation is called `DeviceGroupedConvFwdMultipleABD` and requires the following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **InLayout** - input layout (NHWGC, GNHWC, NGCHW). +* **WeiLayout** - weight layout (GKYXC). +* **DsLayout** - layouts for additional tensors for fused operations. +* **OutLayout** - output layout (NHWGK, GNHWK, NGKHW). +* **ADataType** - input data type. Pass a tuple if there is a fused operation with the input. +* **BDataType** - weight data type. Pass a tuple if there is a fused operation with the weight. +* **DsDataType** - data types for additional tensors for fused operations. +* **EDataType** - output data type. +* **AElementwiseOperation** - fused operation on tensor A (input). +* **BElementwiseOperation** - fused operation on tensor B (weight). +* **CDEElementwiseOperation** - fused operation on tensor C (output). +* **AComputeType** - compute data type of tensor A for mfma instruction (ADataType by default). +* **BComputeType** - compute data type of tensor B for mfma instruction (AComputeType by default). + +Grouped convolution forward supports tensors larger than 2GB.
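The following is a minimal sketch of how these template parameters map onto a concrete instantiation and how matching instances can be enumerated through the instance factory. It assumes a 2D fp16 problem with NHWGC/GKYXC/NHWGK layouts and no fused `Ds` tensors; the header paths and the factory call follow the client example added in this pull request, while the chosen layouts and data types are illustrative only.

```cpp
// Sketch only: enumerate grouped convolution forward instances for a
// 2D fp16 NHWGC/GKYXC/NHWGK problem without fused Ds tensors.
#include <iostream>

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Template arguments follow the parameter list above:
// NumDimSpatial, InLayout, WeiLayout, DsLayout, OutLayout,
// ADataType, BDataType, DsDataType, EDataType, elementwise operations.
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    2,                                     // NumDimSpatial
    ck::tensor_layout::convolution::NHWGC, // InLayout
    ck::tensor_layout::convolution::GKYXC, // WeiLayout
    ck::Tuple<>,                           // DsLayout (no fused tensors)
    ck::tensor_layout::convolution::NHWGK, // OutLayout
    ck::half_t,                            // ADataType (input)
    ck::half_t,                            // BDataType (weight)
    ck::Tuple<>,                           // DsDataType
    ck::half_t,                            // EDataType (output)
    PassThrough,                           // AElementwiseOperation
    PassThrough,                           // BElementwiseOperation
    PassThrough>;                          // CDEElementwiseOperation

int main()
{
    // The instance factory returns every instance built for this problem type;
    // each pointer can then be checked with IsSupportedArgument and run.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " grouped conv fwd instances" << std::endl;
    return 0;
}
```

The full flow (argument creation, support checks and timing) is shown in the `grouped_conv2d_fwd_ngchw.cpp` client example added alongside this README.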
+ +List of the device operations for grouped convolution forward in CK: + +* **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3** - Device operation with XDL instructions. Optimized for AMD Instinct MI300 series. +* **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle** - Device operation with XDL instructions and support of fused operations to input, weight and output. +* **DeviceGroupedConvFwdMultipleD_Wmma_CShuffle** - Device operation with WMMA instructions. +* **DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK** - Device operation with DL instructions. + +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16 |2D, 3D|2D|1D, 2D, 3D| +|fp16 |2D, 3D|2D|1D, 2D, 3D| +|fp32 |2D, 3D|2D|1D, 2D, 3D| +|int8 |2D, 3D|2D|1D, 3D| +|fp8 |3D|✗|✗| +|bf8 |3D|✗|✗| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |2D, 3D|✗|2D, 3D| +|int8 |2D, 3D|✗|2D, 3D| + +Table of supported cases by instance factory with DL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16 |✗|✗|2D| +|fp16 |✗|✗|2D| +|fp32 |✗|✗|2D| +|int8 |✗|✗|2D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Dynamic elementwise operation** - 2D/3D, NHWGC, bf16/fp16/fp32/int8 +* **Bilinear** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **ConvInvScale** - 3D, NHWGC, fp8 +* **ConvScale** - 3D, NHWGC, fp8/bf8 +* **ConvScale + Add** - 3D, NHWGC, fp8 +* **ConvScale + Relu** - 3D, NHWGC, fp8 +* **Scale** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **Scale + Add (for A and B)** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **Scale + Add + Scale + Add + Relu** - 3D, NHWGC, bf16/fp16/fp32/int8 diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp new file mode 100644 index 0000000000000000000000000000000000000000..480abf23d24747f2f4a3e93ba09f8cddbd058139 --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. 
+using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NGCHW; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::NGKHW; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd() +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{C * Hi * Wi, G * C * Hi * Wi, Hi * Wi, Wi, 1}; + std::array wei_lengths{G, K, C, Y, X}; + std::array wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{K * Ho * Wo, G * K * Ho * Wo, Ho * Wo, Wo, 1}; + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + // workspace_sz will be equal to 0 for other layout than NGCHW + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Ho * 
Wo * Y * X + 3 * N * Ho * Wo * G * K; + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * 2 * N * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_fwd(); } diff --git a/client_example/10_grouped_convnd_bwd_data/README.md b/client_example/10_grouped_convnd_bwd_data/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0ed133310e1f4af03e232a6177fca09ad7467240 --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/README.md @@ -0,0 +1,48 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Backward Data + +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. In CK Grouped Convolution Backward Data operation is called as `DeviceGroupedConvBwdDataMultipleD` and requires following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **ALayout** - output layout (NHWGK, GNHWK, NGKHW). +* **BLayout** - weight layout (GKYXC). +* **DsLayout** - layouts for additional tensors for fused operations. +* **ELayout** - input layout (NHWGC, GNHWC, NGCHW). +* **ADataType** - output data type. +* **BDataType** - weight data type. +* **DsDataType** - data types for additional tensors for fused operations. +* **EDataType** - input data type. +* **AElementwiseOperation** - fused operation on tensor A (output). +* **BElementwiseOperation** - fused operation on tensor B (weight). +* **CDEElementwiseOperation** - fused operation on tensor C (input). +* **AComputeType** - compute data type of tensor A for mfma instruction (ADataType by default). 
+* **BComputeType** - compute data type of tensor B for mfma instruction (AComputeType by default). + +Grouped convolution backward data supports tensors larger than 2GB (except when the image is larger than 2GB). + +List of the device operations for grouped convolution backward data in CK: + +* **DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1** - Device operation with XDL instructions and support of fused operations to input. +* **DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle** - Device operation with WMMA instructions. + +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16|2D, 3D|✗|2D, 3D| +|fp16 |2D, 3D|✗|2D, 3D| +|fp32 |2D, 3D|✗|2D, 3D| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |2D, 3D|✗|2D, 3D| +|int8 |2D, 3D|✗|2D, 3D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Bilinear** - 3D, NHWGC, bf16/fp16/fp32 +* **Scale** - 3D, NHWGC, bf16/fp16/fp32 diff --git a/client_example/11_grouped_conv_bwd_weight/README.md b/client_example/11_grouped_conv_bwd_weight/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ed3dff0f1e80dd494fa2b1c5eb051a9a65ba34d4 --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/README.md @@ -0,0 +1,62 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Backward Weight + +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. The convolution utilizes a GEMM kernel after a tensor coordinate transform. The backward weight version uses the splitK feature (due to the large GEMM K dimension). In CK, the Grouped Convolution Backward Weight operation is called `DeviceGroupedConvBwdWeight` and requires the following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **InLayout** - input layout (NHWGC, GNHWC, NGCHW). +* **WeiLayout** - weight layout (GKYXC). +* **OutLayout** - output layout (NHWGK, GNHWK, NGKHW). +* **InDataType** - input data type. +* **WeiDataType** - weight data type. +* **OutDataType** - output data type. +* **InElementwiseOperation** - fused operation on tensor input. +* **WeiElementwiseOperation** - fused operation on tensor weight. +* **OutElementwiseOperation** - fused operation on tensor output. +* **ComputeTypeA** - compute data type of tensor A for mfma instruction (ADataType by default). +* **ComputeTypeB** - compute data type of tensor B for mfma instruction (ComputeTypeA by default). + +For fused operations with an additional tensor, the `DeviceGroupedConvBwdWeightMultipleD` operation is available, which requires the following additional parameters: +* **DsLayout** - layouts for additional tensors for fused operations. +* **DsDataType** - data types for additional tensors for fused operations. + +Grouped convolution backward weight doesn't support tensors larger than 2GB. + +List of the device operations for grouped convolution backward weight in CK: + +* **DeviceGroupedConvBwdWeight_Xdl_CShuffle** - Device operation with XDL instructions. +* **DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle** - Device operation with XDL instructions. Optimized for small C or K. +* **DeviceGroupedConvBwdWeight_Wmma_CShuffle** - Device operation with WMMA instructions.
+* **DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle** - Device operation with XDL instructions and support of fused operations to output. +* **DeviceGroupedConvBwdWeight_Dl** - Device operation with DL instructions. + +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16|2D, 3D|✗|✗| +|bf16(fp32 for weight)|2D, 3D|✗|1D, 2D, 3D| +|fp16 |2D, 3D|✗|1D, 2D, 3D| +|fp32 |2D, 3D|✗|1D, 2D, 3D| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |3D|✗|3D| +|int8 |3D|✗|3D| + +Table of supported cases by instance factory with DL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16(fp32 for weight)|1D, 2D, 3D|✗|1D, 2D, 3D| +|fp16 |1D, 2D, 3D|✗|1D, 2D, 3D| +|fp32 |1D, 2D, 3D|✗|1D, 2D, 3D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Bilinear** - 3D, NHWGC, bf16(fp32 for weight)/fp16/fp32 +* **Scale** - 3D, NHWGC, bf16(fp32 for weight)/fp16/fp32 diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index ce5834d1e2363b8db5449a593e7ec2f9fad3772e..9e2012bf8a7ce28333ac51167d3274144164a570 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -56,7 +56,7 @@ if (GPU_TARGETS) add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() - if (GPU_TARGETS MATCHES "gfx12") + if (GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_OCP_FP8) set(CK_USE_OCP_FP8 "ON") endif() diff --git a/codegen/driver/main.cpp b/codegen/driver/main.cpp index c7d295de943e1feb5b139d933118db745e7edee3..7b878d0d579635b5022a7a0167c73daf68e9d134 100644 --- a/codegen/driver/main.cpp +++ b/codegen/driver/main.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp new file mode 100644 index 0000000000000000000000000000000000000000..301df0a5296a2d1023711820899715f35538a376 --- /dev/null +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// defines all values need for an instance of fwd conv +struct Operation_Xdl_CShuffle +{ + // returns a vector of instances, only given fusion operators: will use default problem spec + static std::vector> + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, given a problem spec and fusion operators + static std::vector + CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue); + TensorDesc A{}; + TensorDesc B{}; + TensorDesc B1{}; + TensorDesc C{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string b1_elem_op = PassThrough; + std::string c_elem_op = PassThrough; + std::string acc_elem_op = Scale; + std::string prologue = ""; + std::string epilogue = ""; + std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; + // tuning parameters + operation::TileDescGemmGemm tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b0_block_transfer{}; + operation::BlockTransferDesc b1_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + bool mask_out_upper_triangle = false; + + // functions to update fusion operators if provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + /**constexpr**/ bool + IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_); + // returns a templated instance + Solution ToSolution() const; +}; + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..428034a3ba9253cdc061dcff56439486e3f0aa54 --- /dev/null +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// defines the problem specification for a GEMM operation +struct Problem +{ + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + std::size_t O = 0; + bool TransA = false; + bool TransB = false; + bool TransB1 = false; + bool TransC = false; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType B1DataType = DataType::Half; + DataType CDataType = DataType::Half; + std::string AElementOp = PassThrough; + std::string BElementOp = PassThrough; + std::string B1ElementOp = PassThrough; + std::string CElementOp = PassThrough; + std::string AccElementOp = Scale; + + // returns the correct device op file for the operation + std::string GetIncludeHeader() const; + + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; +}; + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp index 359da7d8cf5ee48aab4cd6a4e987a49830fac88c..e5eeb6be1584f1c181a90b12922520aa83b9bee8 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp @@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle operation::BlockTransferDesc b_block_transfer{}; operation::CShuffleDesc cshuffle{}; operation::CBlockTransferDesc c_block_transfer{}; + LoopScheduler loop_scheduler{}; + PipelineVersion pipeline_version{}; // functions to update fusion operators if provided void update_prologue(const std::string& prologue); diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp index 84ef92f0a039706e1da4719ca9576667fac44494..5a51a0002e84d427176b5697560f7010bcef01ff 100644 --- a/codegen/include/ck/host/operation/gemm.hpp +++ b/codegen/include/ck/host/operation/gemm.hpp @@ -23,6 +23,26 @@ struct TileDesc int n_Xdl_per_wave = 0; int num_gemmk_prefetch_stage = 0; }; + +struct TileDescGemmGemm +{ + int block_size = 0; + int gemm01_m_per_block = 0; + int gemm0_n_per_block = 0; + int gemm0_k_per_block = 0; + int gemm1_n_per_block = 0; + int gemm1_k_per_block = 0; + int ak1 = 0; + int bk1 = 0; + int b1k1 = 0; + int m_per_XDL = 0; + int n_per_XDL = 0; + int gemm0_m_Xdl_per_wave = 0; + int gemm0_n_Xdl_per_wave = 0; + int gemm1_n_Xdl_per_wave = 0; + int num_gemmk_prefetch_stage = 0; +}; + struct BlockTransferDesc { std::string thread_cluster_length = ""; diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp index 8bad7bf89c55fccb6e052347f0b5689d20200e4a..b05e1341765d296b990dc9ce3b9a94eb61087072 100644 --- a/codegen/include/ck/host/types.hpp +++ b/codegen/include/ck/host/types.hpp @@ -66,6 +66,20 @@ enum class GemmType }; std::string ToString(GemmType gt); +enum class LoopScheduler +{ + Default, + Interwave, +}; +std::string ToString(LoopScheduler ls); + +enum class PipelineVersion +{ + v1, + v2 +}; +std::string ToString(PipelineVersion pv); + struct TensorDesc { DataType element; @@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...}); constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough"; constexpr const char* Bilinear = 
"ck::tensor_operation::element_wise::Bilinear"; +constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale"; } // namespace host } // namespace ck diff --git a/codegen/src/device_batched_gemm_softmax_gemm.cpp b/codegen/src/device_batched_gemm_softmax_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf140ead1df58d9810977609a959747792df2e7e --- /dev/null +++ b/codegen/src/device_batched_gemm_softmax_gemm.cpp @@ -0,0 +1,38 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// return the relevant device op file based on the operation +std::string Problem::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"; +} + +// returns templated instances when provided with a problem specification +std::vector Problem::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations( + *this, prologue, epilogue); // obtains vector of instances + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); // template instance with correct values + }); + return result; +} + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b12c2e1a4afdee23aef1ce9ec8de24439aa20bb6 --- /dev/null +++ b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp @@ -0,0 +1,408 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// calculate appropriate Gemm Specification based on input tensor dimensions +std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t n1, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block, + const std::size_t n1_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0) + spec += "O"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +// function to update prologue/epilogue with user provided operation +void Operation_Xdl_CShuffle::update_prologue(const std::string& pro) +{ + if(!prologue.empty()) + { + this->prologue = pro; + } + else + { + this->prologue = ""; + } +} + +void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) +{ + if(!epilogue.empty()) + { + this->epilogue = epi; + } + else + { + this->epilogue = ""; + } +} + +// accounts for all possible combinations of Row/Col major +static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } + +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Xdl_CShuffle::CreateOperations( + const Problem& prob, const std::string& prologue, const std::string& epilogue) +{ + std::vector result; + + std::vector tile_descriptions = { + // clang-format off +// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK| +// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch| +// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage| +// | | | | | | | | | | | Wave| Wave| Wave| | + { 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1}, + { 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1}, + { 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1}, + { 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1}, + { 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, + { 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, + { 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, + { 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, +// Padded fallback kernel + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1}, +// Irregular k + { 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1}, + { 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1}, + { 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1}, + { 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1}, + { 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 
1}, + // clang-format on + }; + + const std::vector a_block_descriptions = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, +// Padded fallback kernel + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, +// Irregular k + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + // clang-format on + }; + + const std::vector b1_block_descriptions = { + // clang-format off +// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, +// Padded fallback kernel + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, +// Irregular k + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 8}, + { 1, 4}, + { 1, 8}, + { 1, 4}, +// Padded fallback kernel + { 1, 2}, + { 
1, 2}, +// Irregular k + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// | + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1,16>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1,16>, 8}, + { S<1, 32, 1, 8>, 8}, +// Padded fallback kernel + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, +// Irregular k + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + // clang-format on + }; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b1_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b0_block_transfer = a_block_descriptions[i]; // b0 same as a + x.b1_block_transfer = b1_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)}; + x.C = TensorDesc{prob.CDataType, ToLayout(prob.TransC)}; + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.b1_elem_op = prob.B1ElementOp; + x.c_elem_op = prob.CElementOp; + x.acc_elem_op = prob.AccElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + prob.O, + x.tile_desc.gemm01_m_per_block, + x.tile_desc.gemm0_n_per_block, + x.tile_desc.gemm0_k_per_block, + x.tile_desc.gemm1_n_per_block); + x.update_prologue(prologue); + x.update_epilogue(epilogue); + x.mask_out_upper_triangle = true; + result.push_back(x); + + x.mask_out_upper_triangle = false; + result.push_back(x); + } + return result; +} + +// set up instances when not provided with a problem specification, use default operation values and +// all possible layout combinations +std::vector> +Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue) +{ + Problem prob; + prob.TransA = false; + prob.TransB = true; + prob.TransB1 = false; + prob.TransC = false; + + return {CreateOperations(prob, prologue, epilogue)}; +} + +static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate = + "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${LayoutA}, " + "${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, " + "${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, " + "${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, " + "${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, " + "${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, " + "${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, " + 
"${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, " + "${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, " + "${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, " + "${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, " + "${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, " + "${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, " + "${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, " + "${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, " + "${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, " + "${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, " + "${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, " + "${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, " + "${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " + "${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, " + "${CBlockTransferScalarPerVector_NWaveNPerXdl}, ${MaskOutUpperTriangle}>"; + +// use hardcoded instances from vector of operations to substitute values into instance template +Solution Operation_Xdl_CShuffle::ToSolution() const +{ + std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.gemm01_m_per_block) + "_" + + std::to_string(this->tile_desc.gemm0_n_per_block) + "_" + + std::to_string(this->tile_desc.gemm0_k_per_block) + "_" + + std::to_string(this->tile_desc.gemm1_n_per_block) + "_" + + std::to_string(this->tile_desc.gemm1_k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.b1k1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB0", ToString(this->B.layout)}, + {"LayoutB1", ToString(this->B1.layout)}, + {"LayoutC", ToString(this->C.layout)}, + {"ADataType", ToString(this->A.element)}, + {"B0DataType", ToString(this->B.element)}, + {"B1DataType", ToString(this->B1.element)}, + {"CDataType", ToString(this->C.element)}, + {"AccDataType", ToString(this->acc)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"AElementwiseOperation", this->a_elem_op}, + {"B0ElementwiseOperation", this->b_elem_op}, + {"Acc0ElementwiseOperation", this->acc_elem_op}, + {"B1ElementwiseOperation", this->b1_elem_op}, + {"CElementwiseOperation", this->c_elem_op}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", std::to_string(this->tile_desc.block_size)}, + {"Gemm01MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)}, + {"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)}, + {"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)}, + {"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)}, + {"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"B1K1", std::to_string(this->tile_desc.b1k1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", 
std::to_string(this->tile_desc.n_per_XDL)}, + {"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)}, + {"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)}, + {"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + {"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"B0BlockTransferThreadClusterLengths_BK0_N_BK1", + this->b0_block_transfer.thread_cluster_length}, + {"B0BlockTransferThreadClusterArrangeOrder", + this->b0_block_transfer.thread_cluster_arrange_order}, + {"B0BlockTransferSrcAccessOrder", this->b0_block_transfer.src_access_order}, + {"B0BlockTransferSrcVectorDim", std::to_string(this->b0_block_transfer.src_vec_dim)}, + {"B0BlockTransferSrcScalarPerVector", + std::to_string(this->b0_block_transfer.src_scalar_per_vector)}, + {"B0BlockTransferDstScalarPerVector_BK1", + std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)}, + {"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)}, + {"B1BlockTransferThreadClusterLengths_BK0_N_BK1", + this->b1_block_transfer.thread_cluster_length}, + {"B1BlockTransferThreadClusterArrangeOrder", + this->b1_block_transfer.thread_cluster_arrange_order}, + {"B1BlockTransferSrcAccessOrder", this->b1_block_transfer.src_access_order}, + {"B1BlockTransferSrcVectorDim", std::to_string(this->b1_block_transfer.src_vec_dim)}, + {"B1BlockTransferSrcScalarPerVector", + std::to_string(this->b1_block_transfer.src_scalar_per_vector)}, + {"B1BlockTransferDstScalarPerVector_BK1", + std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)}, + {"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CBlockTransferScalarPerVector_NWaveNPerXdl", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + {"MaskOutUpperTriangle", std::to_string(this->mask_out_upper_triangle)}, + }; + + return Solution{InterpolateString(DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate, values), + std::move(values)}; +} + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp index fff75c196263961394445bd27162968a563e112b..fe556615e0897131cc5f03c9c5682e35f5b0cb9c 100644 --- a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) // accounts for all 
possible combinations of Row/Col major static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } +// clang-format off +// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, + +// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +// clang-format on + // Hard-code tuning parameters in modularized fashion, string them together into a vector of // instances std::vector Operation_Xdl_CShuffle::CreateOperations( @@ -83,6 +89,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1}, { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1}, { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1}, +// Irregular tile + { 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1}, // clang-format on }; @@ -100,6 +108,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, // clang-format on }; @@ -109,15 +119,17 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( // ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| // Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | // | | | | | | | + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, // clang-format on - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, }; std::vector b_block_descriptions_rowmajor = { @@ -134,6 +146,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, // clang-format on }; @@ -151,6 +165,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, // clang-format on }; @@ -167,6 
+183,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { 1, 1}, { 1, 1}, { 1, 1}, + { 1, 1}, { 1, 1}, // clang-format on }; @@ -185,6 +202,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<1, 16, 1, 8>, 8}, { S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8}, +// Irregular tile + { S<1, 16, 1, 4>, 1}, // clang-format on }; @@ -199,33 +218,44 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( assert(tile_descriptions.size() == cshuffle_descriptions.size()); assert(tile_descriptions.size() == c_block_descriptions.size()); - // Put all values together into a single operation > store into the result vector - for(std::size_t i = 0; i < tile_descriptions.size(); i++) + const std::vector> scheduler_pipeline_descriptions = + { + {LoopScheduler::Default, PipelineVersion::v1}, + {LoopScheduler::Interwave, PipelineVersion::v1}, + {LoopScheduler::Default, PipelineVersion::v2}, + }; + for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions) { - Operation_Xdl_CShuffle x; - x.tile_desc = tile_descriptions[i]; - x.a_block_transfer = a_block_descriptions[i]; - x.b_block_transfer = b_block_descriptions[i]; - x.cshuffle = cshuffle_descriptions[i]; - x.c_block_transfer = c_block_descriptions[i]; - x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; - x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; - x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; - x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { - return TensorDesc{dt, ToLayout(trans)}; - }); - x.a_elem_op = prob.AElementOp; - x.b_elem_op = prob.BElementOp; - x.cde_elem_op = prob.CDEElementOp; - x.gemm_specialization = GetGemmSpec(prob.M, - prob.N, - prob.K, - x.tile_desc.m_per_block, - x.tile_desc.n_per_block, - x.tile_desc.k_per_block); - x.update_prologue(prologue); - x.update_epilogue(epilogue); - result.push_back(x); + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; + x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { + return TensorDesc{dt, ToLayout(trans)}; + }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + x.tile_desc.m_per_block, + x.tile_desc.n_per_block, + x.tile_desc.k_per_block); + x.loop_scheduler = loop_scheduler; + x.pipeline_version = pipeline_version; + x.update_prologue(prologue); + x.update_epilogue(epilogue); + result.push_back(x); + } } return result; } @@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, " "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " - "${CDEBlockTransferScalarPerVector_NPerBlock}>"; + "${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>"; // use hardcoded instances from vector of operations to substitute 
values into instance template Solution Operation_Xdl_CShuffle::ToSolution() const @@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, {"CDEBlockTransferScalarPerVector_NPerBlock", std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + {"LoopScheduler", ToString(this->loop_scheduler)}, + {"PipelineVersion", ToString(this->pipeline_version)}, }; return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values), diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp index 5b0c929db32fc9a834ec040163b2a5476e2b6d1c..452cd998469702603ed2f33e931ba2622e16ba87 100644 --- a/codegen/src/headers.cpp +++ b/codegen/src/headers.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/headers.hpp" #include "ck_headers.hpp" diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp index a8a8b10c04d522e93dc7340167e46ca83f51259b..a60e36ca4a4ba42b4fec053bc8a2f52e7a1e932c 100644 --- a/codegen/src/types.cpp +++ b/codegen/src/types.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/types.hpp" #include "ck/host/stringutils.hpp" #include @@ -56,6 +59,26 @@ std::string ToString(GemmType gt) throw std::runtime_error("Incorrect gemm type"); } +std::string ToString(LoopScheduler ls) +{ + switch(ls) + { + case LoopScheduler::Default: return "ck::LoopScheduler::Default"; + case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave"; + } + throw std::runtime_error("Incorrect LoopScheduler type"); +} + +std::string ToString(PipelineVersion pv) +{ + switch(pv) + { + case PipelineVersion::v1: return "ck::PipelineVersion::v1"; + case PipelineVersion::v2: return "ck::PipelineVersion::v2"; + } + throw std::runtime_error("Incorrect PipelineVersion type"); +} + std::string SequenceStr(const std::vector& v) { return "ck::Sequence<" + diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index bd7ef463fbe64d5bc3d07665cb4757598657f2ad..9e2d990d9bf5b4ee2434dd9b9700e02e57cde1e4 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_gemm_multiple_d/operation.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp index 50290fa25ad385ce1e657de8cac3042227cc6787..9902caab0496eaa7a182ff2c7896992bc9908cdf 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp index b558d97c783c7f9a2a0901013c5ea4cee053d175..205283e7aad2bd94f09d50063d67ecc1e1bd9ed3 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp index e2972a93d2f7e11aa6b03dc35b7aae6663e70f93..2b83af24321dc021b35386cc2f25b4ca7da7d102 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp index b728096c51e80d0d72f407160275e3699cf5a16a..fbe27e9c8b82f9b5ddf339a11bfc4d5e3cf92c80 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/include/common.hpp b/codegen/test/include/common.hpp index 99d4c6497331f65d19adf302bd47dbaa22ac4b40..24fde2e52358688f1c9ab4ba9d68cd47a0d9a76a 100644 --- a/codegen/test/include/common.hpp +++ b/codegen/test/include/common.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include #include diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp index c4413b47be2b23a36dd2a631794876dae8b98776..a49714f7c6850fd83c443592539c3f6e4a0beded 100644 --- a/codegen/test/rtc/include/rtc/compile_kernel.hpp +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL diff --git a/codegen/test/rtc/include/rtc/hip.hpp b/codegen/test/rtc/include/rtc/hip.hpp index e962d4cd3e1e1573b13272e052eeee646b05dec1..3163bb08edad50baf19566de420c9aa252ed066e 100644 --- a/codegen/test/rtc/include/rtc/hip.hpp +++ b/codegen/test/rtc/include/rtc/hip.hpp @@ -1,8 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP #include #include +#include #include #include diff --git a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp index 9f38e90416e0d2363a921df1ac4268bbc82e55ff..b1ee729f77518f2ddf312c2965454e43e589d8cb 100644 --- a/codegen/test/rtc/include/rtc/kernel.hpp +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL diff --git a/codegen/test/rtc/include/rtc/manage_ptr.hpp b/codegen/test/rtc/include/rtc/manage_ptr.hpp index 92edf1262832d5e69c5751b162d9b5a43aac5a58..52b94d4b70ba3eecb8f2f9b20bf7d7e39e6fa2e0 100644 --- a/codegen/test/rtc/include/rtc/manage_ptr.hpp +++ b/codegen/test/rtc/include/rtc/manage_ptr.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp index a0a2cb9b77480f7c32fb531f77a8ad049024dab2..2f3b26cc43549c7e21b84366e5a2d1eb80f203b4 100644 --- a/codegen/test/rtc/include/rtc/tmp_dir.hpp +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 8cb71b9043cb92c675ce421d668f95a8886291c2..5a70f898e8cd0b0d97c696d0bf4b41dc290db1f6 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include diff --git a/codegen/test/rtc/src/hip.cpp b/codegen/test/rtc/src/hip.cpp index 747f83e3baa240159adcf2e89847f4a1bad245a8..6f16e36720954a7908e3c3ecece4732115797cb1 100644 --- a/codegen/test/rtc/src/hip.cpp +++ b/codegen/test/rtc/src/hip.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include diff --git a/codegen/test/rtc/src/kernel.cpp b/codegen/test/rtc/src/kernel.cpp index 9fe38e84ad6624bcb82d8f3f97a0767ecd92108c..982e95de172fcb2b02633d6f9a6daf7758bf80ba 100644 --- a/codegen/test/rtc/src/kernel.cpp +++ b/codegen/test/rtc/src/kernel.cpp @@ -1,6 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include +#include #include // extern declare the function since hip/hip_ext.h header is broken diff --git a/codegen/test/rtc/src/tmp_dir.cpp b/codegen/test/rtc/src/tmp_dir.cpp index 4e89bc35399075d67dcbc03621c1445c6eb6f66b..b36b17cce1cb50a7e14e06f4956f771373570014 100644 --- a/codegen/test/rtc/src/tmp_dir.cpp +++ b/codegen/test/rtc/src/tmp_dir.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include #include diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 46a61a87fc777ca31ab07b4e0a80dca42edba45f..e9df8c9f5ff144152a18c09b6b003294255a7350 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.12.0 +rocm-docs-core==1.15.0 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index c2e74baae376ec7a23bcd4a09099f16d0fda34aa..a42fdf09bf47e7e86533774a3ea18d5ac7eb0608 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -8,6 +8,13 @@ accessible-pygments==0.0.5 # via pydata-sphinx-theme alabaster==0.7.16 # via sphinx +asttokens==3.0.0 + # via stack-data +attrs==24.3.0 + # via + # jsonschema + # jupyter-cache + # referencing babel==2.15.0 # via # pydata-sphinx-theme @@ -25,9 +32,17 @@ cffi==1.16.0 charset-normalizer==3.3.2 # via requests click==8.1.7 - # via sphinx-external-toc + # via + # jupyter-cache + # sphinx-external-toc +comm==0.2.2 + # via ipykernel cryptography==43.0.0 # via pyjwt +debugpy==1.8.12 + # via ipykernel +decorator==5.1.1 + # via ipython deprecated==1.2.14 # via pygithub docutils==0.21.2 @@ -38,20 +53,56 @@ docutils==0.21.2 # pydata-sphinx-theme # sphinx # sphinxcontrib-bibtex +exceptiongroup==1.2.2 + # via ipython +executing==2.1.0 + # via stack-data fastjsonschema==2.20.0 - # via rocm-docs-core + # via + # nbformat + # rocm-docs-core gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via rocm-docs-core +greenlet==3.1.1 + # via sqlalchemy idna==3.7 # via requests imagesize==1.4.1 # via sphinx +importlib-metadata==8.6.1 + # via + # jupyter-cache + # myst-nb +ipykernel==6.29.5 + # via myst-nb +ipython==8.31.0 + # via + # ipykernel + # myst-nb +jedi==0.19.2 + # via ipython jinja2==3.1.4 # via # myst-parser # sphinx +jsonschema==4.23.0 + # via nbformat +jsonschema-specifications==2024.10.1 + # via jsonschema +jupyter-cache==1.0.1 + # via myst-nb +jupyter-client==8.6.3 + # via + # ipykernel + # nbclient +jupyter-core==5.7.2 + # via + # ipykernel + # jupyter-client + # nbclient + # nbformat latexcodec==3.0.0 # via pybtex markdown-it-py==3.0.0 @@ -60,16 +111,48 @@ markdown-it-py==3.0.0 # myst-parser markupsafe==2.1.5 # via jinja2 +matplotlib-inline==0.1.7 + # via + # ipykernel + # ipython mdit-py-plugins==0.4.1 # via myst-parser mdurl==0.1.2 # via markdown-it-py -myst-parser==3.0.1 +myst-nb==1.1.2 # via rocm-docs-core +myst-parser==3.0.1 + # via myst-nb +nbclient==0.10.2 + # via + # jupyter-cache + # myst-nb +nbformat==5.10.4 + # via + # jupyter-cache + # myst-nb + # nbclient +nest-asyncio==1.6.0 + # via ipykernel packaging==24.1 # via + # ipykernel # pydata-sphinx-theme # sphinx +parso==0.8.4 + # via jedi +pexpect==4.9.0 + # via ipython +platformdirs==4.3.6 + # via jupyter-core +prompt-toolkit==3.0.50 + # via ipython +psutil==6.1.1 + # via ipykernel +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.3 + # via stack-data pybtex==0.24.0 # via # pybtex-docutils @@ -87,26 +170,45 @@ pygithub==2.3.0 pygments==2.18.0 # via # accessible-pygments + # ipython # pydata-sphinx-theme # sphinx pyjwt[crypto]==2.8.0 # via pygithub pynacl==1.5.0 # via pygithub +python-dateutil==2.9.0.post0 + # via jupyter-client pyyaml==6.0.1 # via + # jupyter-cache + # myst-nb # myst-parser # pybtex # rocm-docs-core # sphinx-external-toc +pyzmq==26.2.0 + # via + # ipykernel + # jupyter-client +referencing==0.36.1 + # via + # jsonschema + # jsonschema-specifications requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.12.0 
+rocm-docs-core==1.15.0 # via -r requirements.in +rpds-py==0.22.3 + # via + # jsonschema + # referencing six==1.16.0 - # via pybtex + # via + # pybtex + # python-dateutil smmap==5.0.1 # via gitdb snowballstemmer==2.2.0 @@ -116,6 +218,7 @@ soupsieve==2.5 sphinx==7.4.7 # via # breathe + # myst-nb # myst-parser # pydata-sphinx-theme # rocm-docs-core @@ -149,15 +252,43 @@ sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx +sqlalchemy==2.0.37 + # via jupyter-cache +stack-data==0.6.3 + # via ipython +tabulate==0.9.0 + # via jupyter-cache tomli==2.0.1 # via sphinx +tornado==6.4.2 + # via + # ipykernel + # jupyter-client +traitlets==5.14.3 + # via + # comm + # ipykernel + # ipython + # jupyter-client + # jupyter-core + # matplotlib-inline + # nbclient + # nbformat typing-extensions==4.12.2 # via + # ipython + # myst-nb # pydata-sphinx-theme # pygithub + # referencing + # sqlalchemy urllib3==2.2.2 # via # pygithub # requests +wcwidth==0.2.13 + # via prompt-toolkit wrapt==1.16.0 # via deprecated +zipp==3.21.0 + # via importlib-metadata diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt old mode 100644 new mode 100755 index 957acce165edf371d259fad119fdcf867af19256..97ac21eba5a3722c1f3127e00cdcdb79e01b2634 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -29,10 +29,16 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3) add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp) +add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp) +add_example_executable(example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp) +add_example_executable(example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3) add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3) +add_example_executable(example_gemm_xdl_bf16_streamk_v3 gemm_xdl_bf16_streamk_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_streamk_v3) + add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) @@ -42,9 +48,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16) -add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) -add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) - add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8) @@ -58,7 +61,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index a3a62d4cfa6506e5b9d70c07e0fb1fc9651eae4f..9664c50b6e11ca846d11118de648b908d91cb3b6 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -287,3 +287,85 @@ 
bool parse_cmd_args(int argc, return true; } + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp old mode 100644 new mode 100755 diff --git a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7b491173a6db282811b6733626f4d632ab2d914d --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
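+
+// Overview (descriptive note, inferred from this file's contents): this example
+// drives a split-K GEMM with a bf16 A operand and a packed-int4 (pk_i4_t) B
+// operand through DeviceGemm_Xdl_CShuffleV3. Because PermuteB is enabled below,
+// the host repacks B into KPerBlock-wide blocks before uploading it, and the
+// result is verified against a host ReferenceGemm using the relaxed bf16
+// tolerances provided by get_rtol/get_atol in common.hpp.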
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::pk_i4_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; +static constexpr ck::index_t KPerBlock = 128; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 128, + 16, 64, + KPerBlock, 8, 32, + 16, 16, + 1, 2, + S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); 
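+
+    // (Descriptive note) The block below logs the host tensor descriptors and
+    // allocates device buffers; the B buffer is sized from the permuted B tensor
+    // because, with PermuteB enabled, the weights are first repacked on the host
+    // into [K/KPerBlock][N][KPerBlock] order (the loop nest writing
+    // b_k_n_permute(j * N * K1 + i * K1 + jj)) before being copied to the device.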
+ + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 
2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_rtn.cpp b/example/01_gemm/gemm_xdl_bf16_rtn.cpp deleted file mode 100644 index 108c100cbdf88c0a8e33e9d372daddf2a56894ab..0000000000000000000000000000000000000000 --- a/example/01_gemm/gemm_xdl_bf16_rtn.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -#include "ck/utility/type_convert.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" - -using ADataType = ck::bhalf_t; -using BDataType = ck::bhalf_t; -using CDataType = ck::bhalf_t; -using AccDataType = float; -using CShuffleDataType = float; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CElementOp = ck::tensor_operation::element_wise::ConvertBF16RTN; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - -using ReferenceComputeType = float; -using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; - -#include "run_gemm_example.inc" - -int main(int argc, char* argv[]) { 
return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp b/example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp new file mode 100755 index 0000000000000000000000000000000000000000..5b56a43483b85a8f1b9da07ea394feadd2c682b8 --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 128, 128, + 64, 8, 8, + 16, 16, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 07d51855d6b60bfa32b4e815627dfbe64a0006d8..414683ffdf63893f95629d52f78a7a95a733b9c4 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -31,9 +31,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>; -// // clang-format on -// clang-format off using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| diff --git a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp index 2e27fc66f9456b67ff860e5926f6a24149ab14c3..b0e36b394bb217ea43998923c4f68a75fe413e98 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" -using ADataType = ck::f8_t; -using BDataType = ck::half_t; +using ADataType = ck::half_t; +using BDataType = ck::f8_t; using AccDataType = float; using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; @@ -29,15 +29,15 @@ using DeviceGemmV2Instance = AElementOp, BElementOp, CElementOp, GemmDefault, 64, 16, 16, - 64, 16, 8, + 256, 8, 16, 16, 16, 1, 1, - S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 16, 16, 0, - S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, - ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v1>; + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + 
a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + 
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8a40baa8ad3b30ef3f7808accc0d28b47684fd8 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t KPerBlock = 64; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 128, 128, + KPerBlock, 8, 32, + 32, 32, + 4, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if 
constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; 
i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + b1_scale_device_buf.ToDevice(b1_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + Scale_Stride_BN, + static_cast(b1_scale_device_buf.GetDeviceBuffer()), + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_dequant({K, N}); + + float v_b = 0; + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + + b_k_n_dequant(k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_k_n(k / Scale_Block_K, n / Scale_Block_N)); + } + } + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 
2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_v3.cpp b/example/01_gemm/gemm_xdl_fp16_v3.cpp index ad370f570efd98e90c2bd53fe7522e5ee249586a..4a969246cd80d3aa6bd27cdfb556ca37368fe091 100644 --- a/example/01_gemm/gemm_xdl_fp16_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_v3.cpp @@ -12,7 +12,7 @@ using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; using ALayout = Row; -using BLayout = Row; +using BLayout = Col; using CLayout = Row; using AElementOp = PassThrough; @@ -27,17 +27,17 @@ using DeviceGemmV2Instance = ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, PassThrough, PassThrough, PassThrough, GemmDefault, - 256, - 224, 256, - 64, 8, 2, + 64, + 16, 16, + 256, 8, 8, 16, 16, - 7, 8, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 1, 1, + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 8, 2, 0, - 1, 2, S<1, 32, 1, 8>, 8, - ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 4>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_streamk.cpp b/example/01_gemm/gemm_xdl_streamk.cpp old mode 100644 new mode 100755 index 5a02457dafd1e021b2c0fa71bd1498c891135304..dbdf7199e857969f4cc3b8af5ae7ea56e696bd97 --- a/example/01_gemm/gemm_xdl_streamk.cpp +++ b/example/01_gemm/gemm_xdl_streamk.cpp @@ -15,7 +15,6 @@ using F16 = ck::half_t; using ALayout = Row; using BLayout = Row; -// using BLayout = Col; using CLayout = Row; using AElementOp = PassThrough; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 3ee6e26856f1eab17c7d588305e3b6da7e7765a6..4371af6244cb6c8d1d2288f174d4fcb917618743 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -5,88 +5,6 @@ #include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if 
constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc old mode 100755 new mode 100644 index 04243b8291287e483cdd6362599876ac2050274e..9ee380d247c2b368493fa03ba627d55c0dcb39c6 --- a/example/01_gemm/run_gemm_example_streamk_v2.inc +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -3,88 +3,6 @@ #pragma once -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 5b6969f1d9c9a97e9cdb419b6dfc667892f62985..2b60fa5d2867055f841a8bb749d0ab1a910da5f1 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -3,88 +3,6 @@ #pragma once -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if 
constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index be47665a262ec6619816c249bdfdb96ba3c8ae16..aa9367cdcfc83420fa3015ade75c3090df5bd9de 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -16,7 +16,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index 94ed129dc03664043b1a0295f9f1dcfb38a77002..018b57f82c5d23a53eb4aabda785065e8041686d 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index 720af39af645e622cba1897a46fb1f7004516dae..d5157209449ba15cb9956bd9c04c78ef36b9fc27 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -22,3 +22,6 @@ if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) endif() + +add_example_executable(example_batched_gemm_xdl_fp16int4_b_scale_v3 batched_gemm_xdl_fp16int4_b_scale_v3.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16int4_b_scale_v3) diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..42171bcdb7f16d8368c1cdd259825db91bfed7eb --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp @@ -0,0 +1,82 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using 
AccDataType = F32; +using CShuffleDataType = F16; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto PermuteA = false; +static constexpr bool PermuteB = false; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t KPerBlock = 256; + +// clang-format off +using DeviceBatchedGemmV2Instance = + ck::tensor_operation::device::DeviceBatchedGemm_Xdl_CShuffleV3_BScale< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 16, 64, + KPerBlock, 8, 32, + 16, 16, + 1, 1, + S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>; +// clang-format on + +using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm; +#include "run_batched_gemm_example_fp16int4_b_scale.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_fp16_int4_b_scale_example(argc, argv); } diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc new file mode 100644 index 0000000000000000000000000000000000000000..8c4913dbccd996d322e6a7eda7d736cfcbefe281 --- /dev/null +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -0,0 +1,578 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
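+//
+// Driver for the batched GEMM example with fp16 A, packed int4 (pk_i4_t) B and block-wise
+// B scales: each Scale_Block_K x Scale_Block_N tile of B shares a single BScaleDataType scale.
+// For verification, run_batched_gemm() below dequantizes B on the host as (int4 - 8) * scale
+// and compares the device output against a reference batched GEMM run on the dequantized B.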
+#include + +#pragma once +struct ProblemSize final +{ + ck::index_t M = 128; + ck::index_t N = 128; + ck::index_t K = 384; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + ck::index_t stride_C = N; + + ck::index_t batch_stride_A = M * K; + ck::index_t batch_stride_B = K * N; + ck::index_t batch_stride_C = M * N; + + // Batched Gemm count + ck::index_t batch_count = 2; + + // Split K count + ck::index_t KBatch = 1; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +}; + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count, + KBatch] = problem_size; + + auto f_host_tensor_descriptor = [](std::size_t batch_count_, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t batch_BScale_Stride = + ((K + Scale_Block_K - 1) / Scale_Block_K) * ((N + Scale_Block_N - 1) / Scale_Block_N); + + Tensor a_g_m_k( + f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{})); + Tensor b_g_k_n( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); + Tensor b_g_k_n_permute( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); + Tensor b1_g_k_n( + f_host_tensor_descriptor(batch_count, + (K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + batch_BScale_Stride, + BLayout{})); + + switch(config.init_method) + { + case 0: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + 
b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 4: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "b1_g_k_n: " << b1_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + + DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b1_g_scale_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_g_m_n_device_buf(sizeof(CDataType) * + c_g_m_n_device_result.mDesc.GetElementSpaceSize()); + + printf("a_g_m_k size: %zu, b_g_k_n size: %zu, b1_g_k_n size: %zu, c_g_m_n size: %zu\n", + a_g_m_k.mDesc.GetElementSpaceSize(), + b_g_k_n_permute.mDesc.GetElementSpaceSize(), + b1_g_k_n.mDesc.GetElementSpaceSize(), + c_g_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + printf("Permute B\n"); + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int bs = 0; bs < batch_count; bs++) + { + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_g_k_n_permute(bs * batch_stride_B + j * N * K1 + i * K1 + jj) = + b_g_k_n(bs * batch_stride_B + i * K + (j * K1 + jj)); + } + } + } + } + } + else + { + b_g_k_n_permute = b_g_k_n; + } + + // vector pk_i4x4 permute + for(int bs = 0; bs < batch_count; bs++) + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_g_k_n_permute(bs, j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 6, i) = i4x2; + } + } + } + } + + 
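+    // The loops above rearrange the nibbles within every group of 8 consecutive K values
+    // (4 pk_i4_t bytes) of B: with v0..v7 denoting the int4 values at K offsets 0..7, the
+    // bytes at K offsets 0, 2, 4, 6 are rewritten as (v2,v0), (v6,v4), (v3,v1), (v7,v5)
+    // (high nibble first), i.e. the "01234567->20643175" order noted above, which is
+    // presumably the order expected by the device kernel's packed-int4 B loads.
+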
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); + b_g_k_n_device_buf.ToDevice(b_g_k_n_permute.mData.data()); + b1_g_scale_device_buf.ToDevice(b1_g_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceBatchedGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = + gemm.MakeArgument(static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_g_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_g_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + stride_A, + stride_B, + stride_C, + Scale_Stride_BN, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_BScale_Stride, + static_cast(b1_g_scale_device_buf.GetDeviceBuffer()), + batch_count, // batch count + KBatch, // split K count + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + Tensor b_g_k_n_dequant({batch_count, K, N}); + if(config.do_verification) + { + float v_b = 0; + for(int bs = 0; bs < batch_count; bs++) + { + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + + b_g_k_n_dequant(bs, k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_g_k_n(bs, k / Scale_Block_K, n / Scale_Block_N)); + } + } + } + + auto ref_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_g_m_k, + b_g_k_n_dequant, + c_g_m_n_host_result, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + hip_check_error(hipDeviceSynchronize()); + + c_g_m_n_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_g_m_n_device_result, + c_g_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 
2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + +#if 0 + // print A matrix + printf("A matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&a_g_m_k(bs, 0, 0))); + for(int i = 0; i < M; i++) + { + for(int j = 0; j < K; j++) + { + printf("%.2f,", static_cast(a_g_m_k(bs, i, j))); + } + printf("\n"); + } + } + + // print B matrix original + printf("B matrix original:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&b_g_k_n(bs, 0, 0))); + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + printf("%d,", static_cast(i4)); + } + printf("\n"); + } + } + + // print B matrix + printf("B matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&b_g_k_n_dequant(bs, 0, 0))); + for(int i = 0; i < K; i++) + { + for(int j = 0; j < N; j++) + { + printf("%.2f, ", static_cast(b_g_k_n_dequant(bs, i, j))); + } + printf("\n"); + } + } + + // print B scale matrix + printf("B Scale matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&b1_g_k_n(bs, 0, 0))); + for(int i = 0; i < (K + Scale_Block_K - 1) / Scale_Block_K; i++) + { + for(int j = 0; j < (N + Scale_Block_N - 1) / Scale_Block_N; j++) + { + printf("%.2f, ", static_cast(b1_g_k_n(bs, i, j))); + } + printf("\n"); + } + } + + // print C matrix + printf("C matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf( + "batch %d -> Address: %p\n", bs, static_cast(&c_g_m_n_device_result(bs, 0, 0))); + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + printf("%.2f, ", static_cast(c_g_m_n_device_result(bs, i, j))); + } + printf("\n"); + } + } + + printf("C reference matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&c_g_m_n_host_result(bs, 0, 0))); + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + printf("%.2f, ", static_cast(c_g_m_n_host_result(bs, i, j))); + } + printf("\n"); + } + } +#endif + + return pass; +} + +bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 15); + + problem_size.M = 128 * (dis(gen) + 1); + problem_size.N = 128 * (dis(gen) + 1); + problem_size.K = 256 * (dis(gen) + 2); + + problem_size.batch_count = 2; + + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 7) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + if(argc >= 8) + { + problem_size.batch_count = std::stoi(argv[7]); + } + + if(argc >= 9) + { + problem_size.KBatch = std::stoi(argv[8]); + } + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: 
initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + problem_size.stride_A = problem_size.K; + problem_size.stride_B = problem_size.K; + problem_size.stride_C = problem_size.N; + + problem_size.batch_stride_A = problem_size.M * problem_size.K; + problem_size.batch_stride_B = problem_size.K * problem_size.N; + problem_size.batch_stride_C = problem_size.M * problem_size.N; + + return run_batched_gemm(problem_size, config); +} diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index e3370b880bb29e78a88e5ed4ca1fcdcd14afa2f8..ce42a20be78940a9d6203ac9e8888d4b0bfe4910 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -32,6 +32,56 @@ using BiasLayout = typename LayoutSettingSelector::BiasLayout; template using ResidualLayout = typename LayoutSettingSelector::ResidualLayout; +#if defined(CK_USE_AMD_MFMA_GFX950) +template +using DeviceConvFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InputLayout, + WeightLayout, + ck::Tuple, ResidualLayout>, + OutputLayout, + InKernelDataType, + WeiKernelDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutKernelDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 16, 1, 16>, + 4>; +#else // defined(CK_USE_AMD_MFMA_GFX950) template using DeviceConvFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< @@ -80,6 +130,7 @@ using DeviceConvFwdInstance = 1, S<1, 16, 1, 16>, 4>; +#endif // defined(CK_USE_AMD_MFMA_GFX950) template using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" + +using ScaleDataType = 
ck::e8m0_bexp_t; + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct ExecutionConfig final +{ + int do_verification = 1; // (0=no, 1=CPU) + int init_method = 2; // (0=no init, 1=integer value, 2=decimal value) + bool time_kernel = false; // (0=no, 1=yes) + int verbosity = 0; // (0=no info, 1=verbose info) +}; + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; +}; + +bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.verbosity = std::stoi(argv[4]); + } + else if(argc == 11) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.verbosity = std::stoi(argv[4]); + + problem_size.M = std::stoi(argv[5]); + problem_size.N = std::stoi(argv[6]); + problem_size.K = std::stoi(argv[7]); + + problem_size.StrideA = std::stoi(argv[8]); + problem_size.StrideB = std::stoi(argv[9]); + problem_size.StrideC = std::stoi(argv[10]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl + << "arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC" << std::endl; + return false; + } + + return true; +} + +template +bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using ELayout = CLayout; + using DsLayout = ck::Tuple<>; + using DsDataType = ck::Tuple<>; + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = CElementWiseOp; + + static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; + static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +#if 1 + // XXX: These parameters should not exist in MX-native GEMM kernel + static constexpr ck::index_t Scale_Block_M = 128; + static constexpr ck::index_t Scale_Block_N = 128; +#endif + static constexpr ck::index_t Scale_Block_K = MXVectorSize; + + // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize MX-specific MFMA + // instructions. + // + // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize device-optimized + // scaled type convert functions. + // + // XXX: In DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3, KPerBlock is expected to be equal to + // ScaleBlockK (aka MXVectorSize). + // Additionally, the following is also expected: + // static_assert(ScaleBlockM % MPerBlock == 0); + // static_assert(ScaleBlockN % NPerBlock == 0); + // In MX-native GEMM kernel these requirements should be relaxed. 
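+    //       As a concrete check: with mx_vector_size = 128 (as set in gemm_mx_fp8.cpp), the
+    //       instance below uses ScaleBlockK = KPerBlock = 128 and
+    //       ScaleBlockM = ScaleBlockN = MPerBlock = NPerBlock = 128, so both expectations hold.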
+ // + // XXX: It appears, by default we are using mfma_f32_16x16x4xf32 + // MfmaSelector::selected_mfma.k_per_blk = + // MfmaSelector::selected_mfma.k_per_blk = mfma_f32_16x16x4xf32 + // XXX: GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 assumes scale type is float + + // clang-format off + using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 + // ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB| + // ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | | + // ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | | + // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPSched, BlkGemmPVer, float, float, float, float>; + // clang-format on + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + + auto f_host_tensor_descriptor = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1}); + } + else + { + return HostTensorDescriptor({row, col}, {1, stride}); + } + }; + + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + if(K % Scale_Block_K != 0) + { + throw std::runtime_error("wrong! 
K must be multiple of Scale_Block_K (16 or 32)"); + }; + + auto Scale_Stride_AM = f_get_default_stride(M, K / Scale_Block_K, StrideA, ALayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / Scale_Block_K, N, StrideB, BLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor a_m_k_scale( + f_host_tensor_descriptor(M, K / Scale_Block_K, Scale_Stride_AM, ALayout{})); // scales for A + Tensor b_k_n_scale( + f_host_tensor_descriptor(K / Scale_Block_K, N, Scale_Stride_BN, BLayout{})); // scales for B + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification + Tensor c_m_n_device_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // device result downloaded to host + + if(config.verbosity >= 0) + { + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; + std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl; + } + + switch(config.init_method) + { + case 0: + if(config.verbosity > 0) + { + std::cout << "NOTE: No input data initialization." << std::endl; + } + break; + case 1: + case 2: + ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(0.5f)}(a_m_k_scale); + ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n); + ck::utils::FillConstant{ck::type_convert(2.0f)}(b_k_n_scale); + if(config.verbosity > 0) + { + std::cout << "Init A = {1}" << std::endl; + std::cout << "Init A scale = {0.5}" << std::endl; + std::cout << "Init B = {1}" << std::endl; + std::cout << "Init B scale = {2.0}" << std::endl; + std::cout << "Expect C = {K}" << std::endl; + } + break; + + default: + if(config.verbosity > 0) + { + std::cout << "NOTE: No input data initialization." << std::endl; + } + } + + if(config.verbosity > 0) + std::cout << "Device memory allocation..." << std::endl; + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + if(config.verbosity > 0) + std::cout << "Upload data to device..." << std::endl; + a_device_buf.ToDevice(a_m_k.mData.data()); + a_scale_device_buf.ToDevice(a_m_k_scale.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + b_scale_device_buf.ToDevice(b_k_n_scale.mData.data()); + if(config.verbosity > 0) + std::cout << "Done." 
<< std::endl; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideC, + a_scale_device_buf.GetDeviceBuffer(), + b_scale_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong!\n" + "Provided combination of compilation and runtime parameters is " + "not consistent with the supported device_gemm arguments."); + } + + if(config.verbosity > 0) + std::cout << "Computing GEMM on device..." << std::endl; + float ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50}); + + bool res_verified = true; + if(config.do_verification > 0) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + if(config.verbosity > 0) + { + std::cout << "Done." << std::endl; + std::cout << "Computing GEMM on host..." << std::endl; + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + a_m_k_scale, + b_k_n, + b_k_n_scale, + c_m_n_host_result, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + if(config.verbosity > 0) + { + std::cout << "Done." << std::endl; + std::cout << "Comparing results..." << std::endl; + } + + if(config.init_method == 1) + { + res_verified = + res_verified && std::abs(static_cast(K) - c_m_n_device_result(0, 0)) <= 0.0f; + std::cout << "Expected vs Computed: " << 1.0f * K << " vs " << c_m_n_device_result(0, 0) + << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl; + } + + res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!"); + + if(config.verbosity > 0 && res_verified) + std::cout << "Done." << std::endl; + } + else + { + if(config.verbosity > 0) + std::cout << "Done." << std::endl; + } + + if(config.time_kernel) + { + std::size_t flop = std::size_t(2) * M * N * K + M * K + K * N; // GEMM + A scale + B scale + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + + sizeof(XDataType) * (M * K + K * N) / Scale_Block_K; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + return res_verified; +} + +template +bool run_mx_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && + run_mx_gemm(problem_size, config); +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2e21698ec41d5b4432e0a38b6d27ea5f223751e --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +#if 1 +// XXX: MX-native GEMM kernel will work with e8m0_bexp_t scale type +using XDataType = float; +#else +using XDataType = ck::e8m0_bexp_t; +#endif +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t mx_vector_size = 128; // scaling block size + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 72759916af49282fa97cf8535a124ea07900c2fd..bcb62df62570bded896594e65ff4cc081b7ed12e 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -5,6 +5,14 @@ include_directories(BEFORE add_custom_target(examples) + +# list of examples that are labelled as REGRESSION_EXAMPLE for make regression (runtime more than 30 seconds) +# all other tests are labelled as SMOKE_EXAMPLE +set(REGRESSION_EXAMPLES + example_sparse_embedding3_forward_layernorm +) + + function(add_example_dependencies EXAMPLE_NAME FILE_NAME) if(FILE_NAME) add_dependencies(EXAMPLE_NAME FILE_NAME) @@ -15,34 +23,34 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS FILE_NAME) - set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) - set(test 1) - endif() - if(test EQUAL 1) - message("removing example source file ${source} ") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() - endforeach() + foreach(source IN LISTS FILE_NAME) + set(test 0) + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message("removing example source file ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() endif() set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) @@ -54,9 +62,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() - #Do 
not build any DPP examples if DL_KERNELS not set + #Do not build any DPP examples if DPP_KERNELS not set foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DL_KERNELS AND source MATCHES "_dpp") + if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp") message("removing dpp example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() @@ -75,6 +83,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any microscaling examples if gfx950 target is not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + message("removing microscaling example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #Do not build any FP8 examples if CK_ENABLE_FP8 not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8") @@ -94,7 +109,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) if(FILE_NAME MATCHES "_xdl") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) + elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -107,6 +124,15 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(result 0) endif() #message("add_example returns ${result}") + if(result EQUAL 0 AND NOT "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) + #message("adding to SMOKE EXAMPLE FILTER ${EXAMPLE_NAME}") + set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "SMOKE_TEST") + add_dependencies(smoke ${EXAMPLE_NAME}) + elseif(result EQUAL 0 AND "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) + #message("Adding to REGRESSION EXAMPLE FILTER ${EXAMPLE_NAME}") + set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "REGRESSION_TEST") + add_dependencies(regression ${EXAMPLE_NAME}) + endif() set(result ${result} PARENT_SCOPE) endfunction(add_example_executable EXAMPLE_NAME) @@ -178,7 +204,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) if(FILE_NAME MATCHES "_xdl") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -188,8 +214,10 @@ function(add_example_executable_no_testing EXAMPLE_NAME 
FILE_NAME) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() + #message("add_example returns ${result}") set(result ${result} PARENT_SCOPE) + endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 1ba76a523eed58b60aa67b2a58b9af0825647f06..9ba3a453fc1ed0dc76ea8dd93b47de5002630019 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -102,6 +102,11 @@ else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) endif() +# conditionally specify the use of OCP_FP8 +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + # Allow comparing floating points directly in order to check sentinel values list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index c7ab296c3bbedd5c1358e45cf28fb29b8786c6f5..e9806e7a67ad3782444baf5b7c6a41c8a639ab0f 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -15,8 +15,7 @@ This will result in an executable `build/bin/tile_example_fmha_fwd` ## kernel The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. -There are 3 template parameters for this kernel template. -* `TilePartitioner` is used to map the workgroup to corresponding tile, `fmha_fwd_tile_partitioner.hpp` in this folder served as this purpose. +There are 2 template parameters for this kernel template. * `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). * `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support. 
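+
+For example, after this change a forward kernel type is composed directly from these two pieces.
+The sketch below uses simplified placeholder aliases (`fmha_pipeline`, `fmha_epilogue`) rather than
+the exact aliases emitted by the code generator:
+
+```cpp
+using fmha_kernel = ck_tile::FmhaFwdKernel<fmha_pipeline, fmha_epilogue>;
+```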
diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 83a1e82d6d23e57b522b59033609fa2dd6a84da9..c05660c8ab8ddc16ae0421d2c5091e3c5c784653 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -506,6 +506,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= deterministic == "f" if not cond: continue + if receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'bias'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) @@ -801,4 +809,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 66814f5a16a459e22c2b1523981350072a1a7c4f..ad8daba17ed82417d50201549ea8a5ce53e5cd83 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -29,11 +29,6 @@ K0_MAX_SUBMAX_MAP = { 256: 256 } -TILE_PARTITIONER_MAP = { - "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", - "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", -} - FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n // auto generated by generate.py @@ -90,9 +85,7 @@ using fmha_epilogue_{F_idx} = {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel<{F_tile_partitioner}, - fmha_pipeline_{F_idx}, - fmha_epilogue_{F_idx}>; + ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; @@ -329,12 +322,6 @@ class FmhaFwdKernel: F_pipeline : FmhaFwdPipeline mask_impl : str - def get_tp(self) -> str: - if self.F_mode == 'group': - return 'hbs' - else: - return 'shb' - @property def template(self) -> str: kernel_body = str() @@ -374,13 +361,12 @@ class FmhaFwdKernel: F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ + return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ self.F_tile.name + '_' + self.F_pipeline.name @property @@ -501,13 +487,20 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue - if receipt == 2: + if receipt in (2, 3): cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' if not cond: continue + if receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index fb998a33d7a04bb4917272460f67e74cc3789e68..2f20819302f85eb2ceec358177861e3782daac21 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< fmha_pipeline_problem_{F_idx}>; -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdAppendKVKernel, - fmha_pipeline_{F_idx}>; +using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel; using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; @@ -355,4 +353,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 
2f7edd54773fc4468552e06e13972e7001c33288..37745dd38299e30eb78a756e714f710102788f48 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -48,8 +48,8 @@ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_mask_{F_idx} = {F_mask}; namespace {{ -template -struct kernel_runner {{ +template +struct instance {{ using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_shape = ck_tile::TileFmhaShape; using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< @@ -96,9 +97,7 @@ using fmha_epilogue = {F_spad}, {F_dvpad}>>; using fmha_kernel = - ck_tile::FmhaFwdSplitKVKernel, - fmha_pipeline, - fmha_epilogue>; + ck_tile::FmhaFwdSplitKVKernel; static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ @@ -117,28 +116,50 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wtautological-compare" + +namespace {{ +template +void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ + if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS + && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> + || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ + if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ + instance::run(s, a); + }} else {{ + instance::run(s, a); + }} + }} else {{ + instance::run(s, a); + }} +}} +}} // anonymous namespace + +#pragma clang diagnostic pop + template<> void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if constexpr({F_mode} == false) {{ // batch mode // we don't check every seqlen_k values for kvcache if (a.seqlen_k_ptr != nullptr) {{ - kernel_runner::run(s, a); + run_instance(s, a); // make sure F_bn0 is divisible by F_bk1 }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ - kernel_runner::run(s, a); + run_instance(s, a); }} else {{ - kernel_runner::run(s, a); + run_instance(s, a); }} }} else {{ - kernel_runner::run(s, a); + run_instance(s, a); }} }} template<> std::string fmha_fwd_splitkv_get_name_() {{ - using k_ = kernel_runner::fmha_kernel; /// FIXME: choose real kernel type + using k_ = instance::fmha_kernel; /// FIXME: choose real kernel type return k_::GetName(); }} """ @@ -148,7 +169,7 @@ using fmha_dtype_{F_idx} = {F_dtype}; namespace {{ template -struct kernel_runner {{ +struct instance {{ using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, {F_dvpad}, {F_lse}, @@ -176,11 +197,7 @@ using fmha_epilogue = false, false>>; using fmha_kernel = - ck_tile::FmhaFwdSplitKVCombineKernel< - ck_tile::FmhaFwdSplitKVCombineTilePartitioner< - fmha_pipeline_problem::kM0, fmha_pipeline_problem::kN1>, - fmha_pipeline, - fmha_epilogue>; + ck_tile::FmhaFwdSplitKVCombineKernel; static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ @@ -202,22 +219,22 @@ template<> void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if (a.num_splits <= 8) {{ - kernel_runner<3>::run(s, a); + instance<3>::run(s, a); }} else if (a.num_splits <= 16) {{ - kernel_runner<4>::run(s, a); + instance<4>::run(s, a); }} else if (a.num_splits <= 32) {{ - kernel_runner<5>::run(s, a); + instance<5>::run(s, a); }} else if (a.num_splits <= 64) {{ - kernel_runner<6>::run(s, a); + instance<6>::run(s, a); }} else if (a.num_splits <= 128) {{ - 
kernel_runner<7>::run(s, a); + instance<7>::run(s, a); }} }} template<> std::string fmha_fwd_splitkv_combine_get_name_() {{ - using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type + using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type return k_::GetName(); }} """ diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 08d263da9151137f180c85c3a1adc5d4e770960f..b3855e59dfc1ff149f9c365da21e64b18c07424c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1131,15 +1131,16 @@ bool run(const ck_tile::ArgParser& arg_parser) { // NOTE: use gpu to do validation ck_tile::naive_attention_fwd_traits naive_t; - naive_t.q_type = data_type; - naive_t.k_type = data_type; - naive_t.v_type = data_type; - naive_t.o_type = data_type; - naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd"; - naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd"; - naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd"; - naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd"; - naive_t.variation = 0; // TODO? + naive_t.q_type = data_type; + naive_t.k_type = data_type; + naive_t.v_type = data_type; + naive_t.o_type = data_type; + naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd"; + naive_t.variation = 0; // TODO? + naive_t.quant_algo = 0; ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes()); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 0e821ed5d92ee1eec9b1e36716ed7c8ae1287faf..765c221a7b17630aa3e09786e9430c88df45069c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -400,8 +400,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } }(); - dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); - return ck_tile::make_tuple(kargs, grids); + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } } template @@ -500,8 +510,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) } }(); - dim3 grids = - Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + dim3 grids = Kernel::GridSize( + args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits); return ck_tile::make_tuple(kargs, grids); } diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 5b1b6664ccc05880e2bec2fc5b5b0f7daf1a65f9..a0fb42aa11888b007fbe3702f3c00d1d26ce9647 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -103,7 +103,8 @@ if __name__ == "__main__": required=False, help="codegen receipt. 
0: generate only 8xhdim coverage\n" + \ " 1: generate more instance to cover all hdim\n" + \ - " 2: Only generate instance for Flash attention integration" + " 2: Only generate instance for Flash attention integration\n" + \ + " 4: Only generate instance for PyTorch integration" ) args = parser.parse_args() diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt index 1bf74bc0553296f498004c024477034dd31797d0..fa69ac0f7ac8b2f3044e813cb22a55e294eec832 100644 --- a/example/ck_tile/02_layernorm2d/CMakeLists.txt +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -33,7 +33,7 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS}) set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress) target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS}) diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md index 3573d70cd2615be0e1cc65a62613e4de0015f636..817f62dae7a6c876d4879581546d81e4fd144355 100644 --- a/example/ck_tile/02_layernorm2d/README.md +++ b/example/ck_tile/02_layernorm2d/README.md @@ -59,7 +59,7 @@ args: -kname print kernel name or not (default:1) -prec_i input precision (default:fp16) -prec_o output precision, set auto will be the same as input (default:auto) - -prec_sx output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto) + -prec_sm output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto) -prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto) -fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0) -fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0) @@ -69,7 +69,7 @@ args: ``` ## limitations -Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by default generated. Though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, `N>8192` case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet. If need suport `N>8192` and `fused+residual+store`, you can use this example together with `12_smoothquant`, to construct layernorm+residual, and smoothquant, 2 kernels for this purpose. +Note that `fquant=2`, `fadd=2`, and `prec_sm/prec_sy` other than `fp32` are not generated by default, although our kernel template supports them (TBD: add a flag in generate.py to generate those instances on demand). Besides, the `N>8192` case will use the two-pass pipeline by default, and `-fquant=1/2` is not supported there yet. If you need `N>8192` together with `fused+residual+store`, you can use this example together with `12_smoothquant`, i.e. run layernorm+residual and smoothquant as 2 separate kernels for this purpose.
``` # some case diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index ca9e432a4f1ff914abaa851ffc0dacf7d0716f56..700b007fad5990a130e76eb9917feb46d7d3d085 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation import argparse @@ -23,6 +23,10 @@ def get_if_str(idx, total, lase_else = True): else: return 'else if' +XBIAS_ENUM_STR_MAP = [ + 'no', + 'xbias'] # pre-norm add bias + FUSED_ADD_ENUM_STR_MAP = [ 'no', 'pras', # pre-norm @@ -35,7 +39,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [ DATA_TYPE_MAP = {'fp32' : 'float', 'fp16' : 'ck_tile::fp16_t', 'bf16' : 'ck_tile::bf16_t', - 'int8' : 'ck_tile::int8_t'} + 'int8' : 'ck_tile::int8_t', + 'fp8' : 'ck_tile::fp8_t'} def BOOL_MAP(b_) -> str: if b_: @@ -48,7 +53,7 @@ class layernorm_fwd_codegen: // this is used to pattern-match internl kernel implementation, not to instantiate kernel template struct layernorm2d_fwd_traits_ { using XDataType = ck_tile::remove_cvref_t; using YDataType = ck_tile::remove_cvref_t; - using XScaleDataType = ck_tile::remove_cvref_t; + using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; @@ -120,14 +127,16 @@ struct layernorm2d_fwd_traits_ static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; static constexpr bool kFastFDiv = kFastFDiv_; + static constexpr bool kWelford = kWelford_; static constexpr bool kTwoPass = kTwoPass_; + static constexpr ck_tile::index_t kXbias = kXbias_; static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; }; template using traits_ = layernorm2d_fwd_traits_; """ API_COMMON_HEADER = """ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include "layernorm2d_fwd.hpp" @@ -177,26 +190,29 @@ float layernorm2d_fwd_(const S& s, A a) {{ using XDataType = typename Traits_::XDataType; using YDataType = typename Traits_::YDataType; - using XScaleDataType = typename Traits_::XScaleDataType; + using SmoothScaleDataType = typename Traits_::SmoothScaleDataType; using YScaleDataType = typename Traits_::YScaleDataType; - using ComputeDataType = typename LayerNormTypeConfig::ComputeDataType; + using ComputeDataType = typename LayerNormTypeConfig::ComputeDataType; using PipelineTraits = ck_tile::Layernorm2dFwdTraits(Traits_::kXbias), static_cast(Traits_::kFusedAdd), static_cast(Traits_::kFusedQuant)>; using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< - typename LayerNormTypeConfig::XDataType, - typename LayerNormTypeConfig::GammaDataType, - typename LayerNormTypeConfig::BetaDataType, - typename LayerNormTypeConfig::ComputeDataType, - typename LayerNormTypeConfig::YDataType, - typename LayerNormTypeConfig::MeanDataType, - typename LayerNormTypeConfig::InvStdDataType, - typename LayerNormTypeConfig::XScaleDataType, - typename LayerNormTypeConfig::YScaleDataType, + typename LayerNormTypeConfig::XDataType, + typename LayerNormTypeConfig::XBiasDataType, + typename LayerNormTypeConfig::GammaDataType, + typename LayerNormTypeConfig::BetaDataType, + typename LayerNormTypeConfig::ComputeDataType, + typename LayerNormTypeConfig::YDataType, + typename LayerNormTypeConfig::MeanDataType, + typename LayerNormTypeConfig::InvStdDataType, + typename LayerNormTypeConfig::SmoothScaleDataType, + typename LayerNormTypeConfig::YScaleDataType, typename Traits_::Shape, PipelineTraits>; @@ -204,12 +220,13 @@ float layernorm2d_fwd_(const S& s, A a) using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass; using Pipeline = std::conditional_t; - using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; + using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; using Default2DEpilogue = ck_tile::Default2DEpilogue; static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1; - using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; + static constexpr bool UseRawStore = sizeof(YDataType) == 4; + using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; @@ -233,7 +250,7 @@ float layernorm2d_fwd_(const S& s, A a) API_BASE = """ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include "layernorm2d_fwd.hpp" @@ -269,12 +286,12 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, INSTANCE_BASE = """ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "layernorm2d_fwd_api_common.hpp" // clang-format off -// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep +// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p xbias add sweep {F_instance_def} // clang-format on @@ -284,6 +301,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, self.working_path = working_path self.kernel_filter = kernel_filter + class k_xbias_enum(IntEnum): + F_NO_XBIAS = 0 + F_ADD_XBIAS = 1 + class k_fuesd_add_enum(IntEnum): F_NO_ADD = 0 F_PRE_ADD = 1 @@ -299,6 +320,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, F_kPadN : bool F_kSaveMeanInvStd : bool F_kTwoPass : bool + F_kXbias : Any #: layernorm_fwd_codegen.k_bias_enum F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum @@ -315,6 +337,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, @dataclass class k_problem: F_XDataType : str + F_XBiasDataType : str F_GammaDataType : str F_BetaDataType : str F_ComputeDataType : str @@ -352,7 +375,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, class h_traits: F_XDataType : str F_YDataType : str - F_XScaleDataType : str + F_SmoothScaleDataType : str F_YScaleDataType : str F_Repeat_M : int F_Repeat_N : int @@ -362,15 +385,17 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, F_kPadN : bool F_kSaveMeanInvStd_ : bool F_kFastFDiv_ : bool + F_kWelford_ : bool F_kTwoPass_ : bool + F_kXbias_ : int F_kFusedAdd : int F_kFusedQuant : int @property def trait_name(self) ->str: - t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' - t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}' - t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' + t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}' + t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' return t_ # string when calling this kernel @@ -388,6 +413,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, class h_instance: F_DataTypePair : str F_N : str + F_xbias : int F_add : int F_sweep : int instance_list : List[Any] # List[h_traits] @@ -397,6 +423,8 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, prec_i, prec_o = self.F_DataTypePair.split(',') dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}' + if self.F_xbias != 0: + nnn = nnn + '_' + XBIAS_ENUM_STR_MAP[self.F_xbias] if self.F_add != 0: nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] if self.F_sweep != 0: @@ -422,11 +450,10 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, def name_common_header(self) -> str: return 'layernorm2d_fwd_api_common' - @property - def content_api(self) -> str: + def content_api(self, args) -> str: # 1 sort based on dtype t_dtype_dict = dict() - blobs = self.get_blobs() + blobs = self.get_blobs(args) for blob in blobs: if 
blob.F_DataTypePair not in t_dtype_dict: t_dtype_dict[blob.F_DataTypePair] = {} @@ -451,19 +478,19 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, if ins.F_kFusedQuant == 0: _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) elif ins.F_kFusedQuant == 1: - _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sx == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( - f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_XScaleDataType, f_sy_type=ins.F_YScaleDataType) + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType) elif ins.F_kFusedQuant == 2: _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) - _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( - f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, + _cond = '((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( + f_vec_n = ins.F_Vector_N, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd, f_sweep_cond = _sweep_cond) inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), F_VEC_COND = _cond, F_instance_func=ins.call_name) #inner_str = inner_str + vec_str - n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else '' - n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) + n_cnd = f'(a.n <= {n_})' if isinstance(n_, int) else '' + n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) prec_i, prec_o = dtype_.split(',') d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) @@ -474,77 +501,80 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, def content_common_header(self) -> str: return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE) - def get_blobs(self): + def get_blobs(self, args): h_traits = layernorm_fwd_codegen.h_traits h_instance = layernorm_fwd_codegen.h_instance - dynamic_quant_out_dtype = ['int8'] + dynamic_quant_out_dtype = ['int8', 'fp8'] # some predefined support range # (prec_i,prec_o) for simplicity this string will be used as key for dict scale_list = [('fp32,fp32')] dtype_list = [('fp16,fp16'), ('bf16,bf16'), - ('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out + ('fp16,int8'), ('bf16,int8'), + ('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 or fp8 out + types_8bit = ('int8', 'fp8') + types_16bit = ('int16', 'fp16', 'bf16') #fused_add_list = [0, 1, 2] #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant + xbias_list = [0, 1] fused_add_list = [0, 1] fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant - - # rm rn tm tn vn pd mv fdiv 2p add sweep - h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)], - '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, 
True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)], - '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)], - '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)], - '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0), - h_traits('x', 'y', 
'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]} + # rm rn tm tn vn pd mv fdiv welford 2p xbias add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 
2, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N - for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list): + for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list): prec_i, prec_o = dtype.split(',') - scale_x, scale_y = scale_type.split(',') + scale_sm, scale_y = scale_type.split(',') if prec_o in dynamic_quant_out_dtype and fused_quant != 1: continue # skip non dynamic quant case if fused_quant == 1 and hs_key == 'big': @@ -554,20 +584,32 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, h_ = copy.copy(chs_) # copy the base instance out h_.F_XDataType = prec_i h_.F_YDataType = prec_o - h_.F_XScaleDataType = scale_y - h_.F_YScaleDataType = scale_x + h_.F_SmoothScaleDataType = scale_sm + h_.F_YScaleDataType = scale_y + h_.F_kXbias = xbias h_.F_kFusedAdd = fused_add h_.F_kFusedQuant = fused_quant + # disable welford update for 8bit and 16 bit smallN + if not h_.F_kTwoPass_: + #disable 16 bit when set args disable_16b_welford + if args.disable_16b_welford and prec_i in types_16bit: + h_.F_kWelford_ = False + #disable 8bit by default + elif prec_i in types_8bit or prec_o in types_8bit: + h_.F_kWelford_ = False + #disable 16bit small N + elif prec_i in types_16bit and hs_key == '64': + h_.F_kWelford_ = False current_hs.append(h_) # + "\n" #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ current_n_str = 'big' if hs_key == 'big' else current_n - total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) + total_blob.append(h_instance(dtype, current_n_str, xbias, fused_add, fused_quant, current_hs)) return total_blob - def list_blobs(self) -> None: + def list_blobs(self, args) -> None: w_p = Path(self.working_path) list_p = w_p / 'layernorm2d_fwd_blobs.txt' - blobs = self.get_blobs() + blobs = self.get_blobs(args) with list_p.open('w') as list_f: # api related file list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") @@ -576,11 +618,12 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, for b in blobs: list_f.write(str(w_p / (b.name + ".cpp")) 
+ "\n") - def gen_blobs(self) -> None: + def gen_blobs(self, args) -> None: w_p = Path(self.working_path) - (w_p / (self.name_api + ".cpp")).write_text(self.content_api) + w_str = self.content_api(args) + (w_p / (self.name_api + ".cpp")).write_text(w_str) (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) - blobs = self.get_blobs() + blobs = self.get_blobs(args) for b in blobs: (w_p / (b.name + ".cpp")).write_text(b.content) @@ -588,14 +631,14 @@ def list_blobs(args): api_list = args.api.split(',') for api in api_list: if api == 'fwd': - layernorm_fwd_codegen(args.working_path, args.filter).list_blobs() + layernorm_fwd_codegen(args.working_path, args.filter).list_blobs(args) def gen_blobs(args): api_list = args.api.split(',') for api in api_list: if api == 'fwd': - layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs() + layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args) if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -663,6 +706,13 @@ if __name__ == "__main__": help="codegen receipt." ) + parser.add_argument( + "--disable_16b_welford", + default=False, + required=False, + help="enable/disable welford for 16bit datatype n > 64" + ) + args = parser.parse_args() # print(f'{args.list_blobs}-{args.gen_blobs}') diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index b49c04619d54c6c401128739a236259e04c54dd7..b72485222e6a791d5e270db650be1725fff161e6 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -20,6 +20,14 @@ auto get_elimit() return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1.0; + return ck_tile::make_tuple(rtol, atol); +} + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -35,12 +43,13 @@ auto create_args(int argc, char* argv[]) .insert("kname", "1", "print kernel name or not") .insert("prec_i", "fp16", "input precision") .insert("prec_o", "auto", "output precision, set auto will be the same as input") - .insert("prec_sx", + .insert("prec_sm", "auto", "output quant scale type, set auto will use fp32. used when fquant=1") .insert("prec_sy", "auto", "output quant scale type, set auto will use fp32. 
used when fquant=1 or 2") + .insert("xbias", "0", "add bias, 0:no add, 1:add bias before fadd") .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("warmup", "5", "cold iter") @@ -52,7 +61,7 @@ auto create_args(int argc, char* argv[]) template bool run(const ck_tile::ArgParser& arg_parser) @@ -74,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser) float epsilon = arg_parser.get_float("e"); std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_o = arg_parser.get_str("prec_o"); - std::string prec_sx = arg_parser.get_str("prec_sx"); + std::string prec_sm = arg_parser.get_str("prec_sm"); std::string prec_sy = arg_parser.get_str("prec_sy"); if(prec_o == "auto") { prec_o = prec_i; } - if(prec_sx == "auto") + if(prec_sm == "auto") { - prec_sx = "fp32"; + prec_sm = "fp32"; } if(prec_sy == "auto") { @@ -93,20 +102,25 @@ bool run(const ck_tile::ArgParser& arg_parser) int do_validation = arg_parser.get_int("v"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); + int xbias = arg_parser.get_int("xbias"); int fused_add = arg_parser.get_int("fadd"); int fused_quant = arg_parser.get_int("fquant"); - if(fused_quant == 1 && prec_o != "int8") + if(fused_quant == 1 && prec_o != "int8" && prec_o != "fp8") { - std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl; + std::cout + << "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases." + << std::endl; return false; } assert(x_stride >= n); - using TypeConfig = LayerNormTypeConfig; + using TypeConfig = + LayerNormTypeConfig; using XDataType = typename TypeConfig::XDataType; using YDataType = typename TypeConfig::YDataType; + using XBiasDataType = typename TypeConfig::XBiasDataType; using GammaDataType = typename TypeConfig::GammaDataType; using BetaDataType = typename TypeConfig::BetaDataType; using XResidualDataType = XDataType; @@ -121,6 +135,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // host verify ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); + ck_tile::HostTensor x_bias_host({n}); ck_tile::HostTensor gamma_host({n}); ck_tile::HostTensor beta_host({n}); @@ -135,30 +150,33 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor y_scale_host_ref({m}); ck_tile::HostTensor y_scale_host_dev({m}); - ck_tile::HostTensor x_scale_host({n}); - ck_tile::HostTensor x_scale_host_dev({n}); + ck_tile::HostTensor sm_scale_host({n}); + ck_tile::HostTensor sm_scale_host_dev({n}); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution{-.5f, .5f}(x_residual_host); - ck_tile::FillUniformDistribution{-1.f, 1.f}(x_scale_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(sm_scale_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_bias_host); ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); ck_tile::FillUniformDistribution{-.5f, .5f}(beta_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_bias_buf(x_bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); - ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes()); + 
ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); + x_bias_buf.ToDevice(x_bias_host.data()); gamma_buf.ToDevice(gamma_host.data()); beta_buf.ToDevice(beta_host.data()); x_residual_buf.ToDevice(x_residual_host.data()); - x_scale_buf.ToDevice(x_scale_host.data()); + sm_scale_buf.ToDevice(sm_scale_host.data()); auto prec_str = [&]() { auto base_str = prec_i; @@ -179,11 +197,12 @@ bool run(const ck_tile::ArgParser& arg_parser) << ", yr_stride:" << yr_stride << std::flush; layernorm2d_fwd_traits traits{ - prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; + prec_i, prec_o, prec_sm, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant}; layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, - fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr, + x_bias_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(), @@ -210,8 +229,9 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } - std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + - sizeof(BetaDataType) * n + sizeof(YDataType) * m * n; + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XBiasDataType) * n + + sizeof(GammaDataType) * n + sizeof(BetaDataType) * n + + sizeof(YDataType) * m * n; float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; @@ -221,6 +241,22 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { // reference + if(xbias != 0) + { + // add bias before fadd + int M = x_host.mDesc.get_lengths()[0]; + int N = x_host.mDesc.get_lengths()[1]; + for(int idx_m = 0; idx_m < M; ++idx_m) + { + for(int idx_n = 0; idx_n < N; ++idx_n) + { + x_host(idx_m, idx_n) = ck_tile::type_convert( + ck_tile::type_convert(x_host(idx_m, idx_n)) + + ck_tile::type_convert(x_bias_host(idx_n))); + } + } + } + if(fused_add != 0) { // fused pre_add/pre_add_store @@ -254,8 +290,8 @@ bool run(const ck_tile::ArgParser& arg_parser) for(int n_ = 0; n_ < N_; n_++) { // input smooth outlier - acc_(m_, n_) = - acc_(m_, n_) * ck_tile::type_convert(x_scale_host(n_)); + acc_(m_, n_) = acc_(m_, n_) * + ck_tile::type_convert(sm_scale_host(n_)); } } ComputeDataType absmax = static_cast(0); @@ -265,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser) absmax = a > absmax ? a : absmax; } // printf("cpu:absmax:%f\n", absmax); - ComputeDataType y_scale = absmax / static_cast(127.0); + constexpr ComputeDataType kMaxY = + std::is_same::value ? 240.0 + : std::is_same::value ? 
127.0 + : 0.0; + ComputeDataType y_scale = absmax / kMaxY; y_scale_host_ref(m_) = ck_tile::type_convert(y_scale); for(int n_ = 0; n_ < N_; n_++) { @@ -308,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser) y_residual_buf.FromDevice(y_residual_host_dev.data()); } - auto [rtol, atol] = get_elimit(); + auto [rtol, atol] = get_elimit(); if(x_stride == n) { @@ -377,16 +417,16 @@ int main(int argc, char* argv[]) std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_o = arg_parser.get_str("prec_o"); - std::string prec_sx = arg_parser.get_str("prec_sx"); + std::string prec_sm = arg_parser.get_str("prec_sm"); std::string prec_sy = arg_parser.get_str("prec_sy"); if(prec_o == "auto") { prec_o = prec_i; } - if(prec_sx == "auto") + if(prec_sm == "auto") { - prec_sx = "fp32"; + prec_sm = "fp32"; } if(prec_sy == "auto") { @@ -395,37 +435,47 @@ int main(int argc, char* argv[]) int save_mv = arg_parser.get_int("save_mv"); // no dynamic quant case - if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && save_mv) + if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_mv) { return run(arg_parser) ? 0 : -2; } - else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32" && + else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && !save_mv) { return run(arg_parser) ? 0 : -2; } - else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && save_mv) { return run(arg_parser) ? 0 : -2; } - else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32" && + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && !save_mv) { return run(arg_parser) ? 0 : -2; } // dynamic quant case, only in inference - else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && + else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && !save_mv) { return run(arg_parser) ? 0 : -2; } - else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32" && + else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && !save_mv) { return run(arg_parser) ? 0 : -2; } + else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } return -3; } diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp index a0f2db0e8a478da5a4302fe7439aa1354d3b923a..0538953a580e76322920c3bd1cdb3db78971b591 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -8,35 +8,40 @@ #include "ck_tile/ops/layernorm2d.hpp" #include -template +template struct LayerNormTypeConfig; -template -struct LayerNormTypeConfig +template +struct LayerNormTypeConfig { - using XDataType = ck_tile::half_t; - using YDataType = OutType; - using GammaDataType = ck_tile::half_t; - using BetaDataType = ck_tile::half_t; - using MeanDataType = ck_tile::half_t; - using InvStdDataType = ck_tile::half_t; - using ComputeDataType = float; - using XScaleDataType = XScaleDataType_; - using YScaleDataType = YScaleDataType_; + using XDataType = ck_tile::half_t; + using YDataType = OutType; + using XBiasDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using BetaDataType = ck_tile::half_t; + using MeanDataType = ck_tile::half_t; + using InvStdDataType = ck_tile::half_t; + using ComputeDataType = float; + using SmoothScaleDataType = SmoothScaleDataType_; + using YScaleDataType = YScaleDataType_; }; -template -struct LayerNormTypeConfig +template +struct LayerNormTypeConfig { - using XDataType = ck_tile::bf16_t; - using YDataType = OutType; - using GammaDataType = ck_tile::bf16_t; - using BetaDataType = ck_tile::bf16_t; - using MeanDataType = ck_tile::bf16_t; - using InvStdDataType = ck_tile::bf16_t; - using ComputeDataType = float; - using XScaleDataType = XScaleDataType_; - using YScaleDataType = YScaleDataType_; + using XDataType = ck_tile::bf16_t; + using YDataType = OutType; + using XBiasDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using BetaDataType = ck_tile::bf16_t; + using MeanDataType = ck_tile::bf16_t; + using InvStdDataType = ck_tile::bf16_t; + using ComputeDataType = float; + using SmoothScaleDataType = SmoothScaleDataType_; + using YScaleDataType = YScaleDataType_; }; // runtime args @@ -50,13 +55,14 @@ struct layernorm2d_fwd_traits std::string prec_i; // input precision std::string prec_o; // output precision - // if fused_quant == 1, need set prec_sx/prec_sy to proper string, otherwise can set + // if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set // arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise // can set arbitrary(will skip check) - std::string prec_sx; // x-scale, used for [1*N] input smooth quant + std::string prec_sm; // x-scale, used for [1*N] input smooth quant std::string prec_sy; // y-scale, used for [M*1] output for next layer bool save_mean_var; // + int xbias; // 0:no-bias, 1:add bias int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant }; diff --git a/example/ck_tile/02_layernorm2d/script/smoke_test.sh b/example/ck_tile/02_layernorm2d/script/smoke_test.sh index b7fd354bb8e05647dc66ee9f4699757bc8378d8b..ceaf262bd9eb79308d697f0493cdc8536f4f227a 100755 --- a/example/ck_tile/02_layernorm2d/script/smoke_test.sh +++ b/example/ck_tile/02_layernorm2d/script/smoke_test.sh @@ -1,7 +1,7 @@ #!/bin/sh EXE="$(find . 
-name tile_example_layernorm2d_fwd -type f | head -n 1)" -for fquant in "" "-fquant=1 -prec_o=int8"; do +for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=1 -prec_o=fp8"; do for pr_i in "fp16" "bf16" ; do for fadd in "0" "1"; do $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 @@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 -#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=9120 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 done done diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index d166eed458fd37394bfd8111bc65c9396880b932..30cfee22f63c5c892463ea3eb2232ce4489aa62d 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,2 +1,5 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) -add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp) +add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) +target_compile_options(tile_example_gemm_universal PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 +) diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index e9ffe72a9152d22b62680369ac03ab569c0eecce..4c16f13cefdcfa474c534ce16fc8e50d742999eb 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -11,9 +11,9 @@ sh ../script/cmake-ck-dev.sh ../ # The basic pipeline method on the gemm calculation make tile_example_gemm_basic -j # The memory bound pipeline on the gemm calculation -make tile_example_gemm_mem_pipeline -j +make tile_example_gemm_universal -j ``` -This will result in an executable `build/bin/tile_example_gemm_basic` +This will result in the executables `build/bin/tile_example_gemm_basic` and `build/bin/tile_example_gemm_universal` ## example ``` @@ -22,6 +22,9 @@ args: -m m dimension (default:1024) -n n dimension (default:2048) -k k dimension (default:64) + -a_layout Tensor A data layout (default: R) + -b_layout Tensor B data layout (default: R) + -c_layout Tensor C data layout (default: R) -stride_a Tensor A stride (default:0) -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 4c630375f4420de8c2ad5819d2e581206fac5006..5dc7b9cd0b9310094e05eaf756899693d3722dbb 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -9,12 +9,16 @@ #include #include -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -template +template float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. 
@@ -22,16 +26,12 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr bool kPadN = false; constexpr bool kPadK = false; - constexpr bool kTilePermute = false; - // The rank and permutation will also be generate out by the CodeGen part. - constexpr ck_tile::index_t kOutputRank = 2; - constexpr int kBlockPerCu = 1; // This part comes from the Codegen constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -39,42 +39,33 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - - // Whether doing the CShuffle (transpose before the global memory), depending on the output - // layout. - constexpr bool CShuffleEpilogue = - std::is_same_v; + constexpr ck_tile::index_t K_Warp_Tile = 16; using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTilePartitioner; - - using GemmEpilogue = std::conditional_t< - CShuffleEpilogue, - ck_tile::CShuffleEpilogue>, - ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>>; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; using CodegenGemmTraits = ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; - using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1; + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. 
using Kernel = ck_tile::GemmKernel; @@ -91,8 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } @@ -105,4 +99,46 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& #include "run_gemm_example.inc" +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "C") + { + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 58cdaea7d85f7dffb7e888fca9be3160b38484d2..636b34981fb1a4100d897afcd0f7ad60a64f3b26 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -8,6 +8,32 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 +#endif + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif template struct GemmBasicTypeConfig; @@ -22,6 +48,33 @@ struct GemmBasicTypeConfig // ToDo: Add more bias config to support different categories of GEMM. }; +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template struct DataTypeTraits; @@ -43,23 +96,32 @@ struct DataTypeTraits static constexpr const char* name = "fp16"; }; -using Types = GemmBasicTypeConfig; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; -// Specific type aliases for easy access -using ADataType = Types::ADataType; -using BDataType = Types::BDataType; -using AccDataType = Types::AccDataType; -using CDataType = Types::CDataType; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("b", "1", "batch size") - .insert("m", "3840", "m dimension") + arg_parser.insert("m", "3840", "m dimension") .insert("n", "4096", "n dimension") .insert("k", "2048", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") @@ -68,7 +130,9 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp16", "data type. 
fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 68df389bfc71884b03d0eedb802da2478da2be6a..042ad372dc5914e9ff76ffcabbf1bbb061ceec3c 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -1,8 +1,42 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -template +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, @@ -28,8 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + float ave_time = + gemm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = @@ -39,13 +74,16 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C - << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << std::endl; + << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name + << " B Type = " << DataTypeTraits::name + << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } -template +template int run_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -56,6 +94,11 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename 
GemmBasicTypeConfig::AccDataType; + ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); @@ -64,56 +107,35 @@ int run_gemm_example_with_layouts(int argc, ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - ck_tile::index_t batch_size = arg_parser.get_int("b"); - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); - - using namespace ck_tile::literals; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if constexpr(std::is_same_v) - { - return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - auto f_get_default_stride = [](std::size_t row, - std::size_t col, - std::size_t stride, - auto layout) { - if(stride == 0) - { - // give a chance if stride is zero, return a default packed stride - if constexpr(std::is_same_v) - { - return col; - } - else - { - return row; - } - } - else - return stride; - }; - - stride_A = f_get_default_stride(M, K, stride_A, a_layout); - stride_B = f_get_default_stride(K, N, stride_B, b_layout); - stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); - - ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, a_layout)); - ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, b_layout)); - ck_tile::HostTensor c_m_n_dev_result( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); - // TODO: add different init types - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + if (init_method == 0) { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + } else if (init_method == 1) { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } else if (init_method == 2) { + ck_tile::FillConstant{static_cast(1)}(a_m_k); + ck_tile::FillConstant{static_cast(1)}(b_k_n); + } else { + a_m_k.SetZero(); + b_k_n.SetZero(); + } ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); @@ -124,18 +146,19 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - batch_size, - n_warmup, - n_repeat); + invoke_gemm( + a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); 
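
For reference, the new split_k and init arguments parsed above flow into this call as the kbatch value and the fill selection; a hedged command-line sketch, not part of the patch (binary location, matrix sizes, and iteration counts are illustrative only — the scripts added later in this patch locate the binary with find):

# hedged usage sketch; flags mirror create_args in gemm_basic.hpp
./tile_example_gemm_basic -prec=fp16 -m=1024 -n=2048 -k=512 \
    -a_layout=R -b_layout=C -c_layout=R \
    -split_k=1 -init=2 -warmup=10 -repeat=50 -v=1
# -init: 0 random, 1 linear, 2 constant(1); -v=1 runs the CPU reference check below,
# whose rtol/atol come from calculate_rtol_atol and account for split-K accumulation
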
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; @@ -143,20 +166,30 @@ int run_gemm_example_with_layouts(int argc, if(arg_parser.get_int("v") == 1) { ck_tile::HostTensor c_m_n_host_ref( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); ck_tile::reference_gemm( a_m_k, b_k_n, c_m_n_host_ref); - - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); - - std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; } else if(arg_parser.get_int("v") == 2) { ck_tile::HostTensor c_m_n_gpu_ref( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); @@ -196,46 +229,21 @@ int run_gemm_example_with_layouts(int argc, ck_tile::hip_check_error(hipFree(d_C)); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); - - std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; } return pass; } - -int run_gemm_example(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); - - if(a_layout == "R" && b_layout == "R") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } - // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not - // work. 
- // else if(a_layout == "C" && b_layout == "C") - // { - // return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - // } - // else if(a_layout == "C" && b_layout == "R") - // { - // return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - // } - else - { - throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); - } -} diff --git a/example/ck_tile/03_gemm/script/benchmark_basic.sh b/example/ck_tile/03_gemm/script/benchmark_basic.sh new file mode 100755 index 0000000000000000000000000000000000000000..64d2ddbb5cd9085de7a2373005afc95472f3f6bb --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_basic.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "64" "512" "1024" "2048"; do + for n in "512" "1024" "2048"; do + for k in "64" "512" "1024" "2048"; do + $EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh new file mode 100644 index 0000000000000000000000000000000000000000..21462616be3681bc696baef0d4554dbf77b64dac --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh @@ -0,0 +1,14 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" +VALID=1 + + +for b_matrix_layout in "C"; do + for m in "64" "512" "1024" "2048"; do + for n in "512" "1024" "2048"; do + for k in "64" "512" "1024" "2048"; do + $EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh new file mode 100755 index 0000000000000000000000000000000000000000..c4cf4ddcbfba8815f19456f857827fdfdaba1ce6 --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh new file mode 100644 index 0000000000000000000000000000000000000000..903b4a3c0ff385408cbb47aff660f5c57c2b802f --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . 
-name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh new file mode 100644 index 0000000000000000000000000000000000000000..8c92c2e99116047506b5f417433afc1ca11e2d0a --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh new file mode 100644 index 0000000000000000000000000000000000000000..e238006c7d0e0cf15c63dc0cfb456415d77f24d1 --- /dev/null +++ b/example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh @@ -0,0 +1,13 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +VALID=1 + +for b_matrix_layout in "C"; do + for m in "512" "1024" "2048" "4096"; do + for n in "512" "1024" "2048"; do + for k in "512" "1024" "2048"; do + $EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/run_full_test.sh b/example/ck_tile/03_gemm/script/run_full_test.sh index 2e2e7fdf90524bfdaedb240b4bc13c6556be9225..45bd1bed614f9e4c0f4583ff3f91bc129f5dd545 100755 --- a/example/ck_tile/03_gemm/script/run_full_test.sh +++ b/example/ck_tile/03_gemm/script/run_full_test.sh @@ -19,7 +19,27 @@ echo 'Host name: ' $host_name export GPU_arch=$4 echo 'GPU_arch: ' $GPU_arch +function print_log_header(){ + rm -f $1; + echo 'On branch ' $3 &> $1; + echo 'Node name: ' $4 >> $1; + # get GPU architecture and compute units from rocminfo + echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1; + rocminfo | grep "Compute Unit:" >> $1; + hipcc --version | grep -e 'HIP version' >> $1; + echo 'Environment type: ' $2 >> $1; + /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1; +} + # run verification tests -example/ck_tile/03_gemm/script/smoke_test.sh +example/ck_tile/03_gemm/script/smoke_test_basic.sh +example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh + +# run performance benchmarks +export gemm_basic_log="perf_tile_gemm_basic_fp16_$GPU_arch.log" +print_log_header $gemm_basic_log $env_type $branch $host_name +example/ck_tile/03_gemm/script/benchmark_basic.sh 2>&1 | tee -a $gemm_basic_log -# We do not have a performance benchmark for gemm yet. Will add it in the future. 
\ No newline at end of file +export gemm_mem_pipeline_log="perf_tile_gemm_mem_pipeline_fp16_$GPU_arch.log" +print_log_header $gemm_mem_pipeline_log $env_type $branch $host_name +example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh 2>&1 | tee -a $gemm_mem_pipeline_log diff --git a/example/ck_tile/03_gemm/script/smoke_test.sh b/example/ck_tile/03_gemm/script/smoke_test.sh deleted file mode 100755 index 4d9a64bf40dc38f19dd7612ce1dfb9cd14e5ab45..0000000000000000000000000000000000000000 --- a/example/ck_tile/03_gemm/script/smoke_test.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" -KNAME=1 - -export CK_WARMUP=0 -export CK_REPEAT=1 - -COMMON_ARGS='-v=2 -warmup=0 -repeat=1' - -run_fp16_tests() { - for batch in 1 2; do - for m in 128 1024; do - for n in 128 2048; do - for k in 32 64; do - - $EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS - if [ $? -eq 0 ]; then - echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." - else - echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." - # Optionally, exit or break if you need to halt further execution - # exit 1 - fi - - done - done - done - done -} - -set -x - -run_fp16_tests - -set +x \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/smoke_test_basic.sh b/example/ck_tile/03_gemm/script/smoke_test_basic.sh new file mode 100755 index 0000000000000000000000000000000000000000..7ca6759f420bc0f0af3868ac2f25644f2e6fd689 --- /dev/null +++ b/example/ck_tile/03_gemm/script/smoke_test_basic.sh @@ -0,0 +1,36 @@ +#!/bin/bash +EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=2 -warmup=0 -repeat=1' + +run_tests() { + for m in 128 1024; do + for n in 128 2048; do + for k in 64 128; do + + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly." + # Optionally, exit or break if you need to halt further execution + # exit 1 + fi + + done + done + done +} + +set -x + +run_tests "fp16" +run_tests "bf16" +run_tests "fp8" +run_tests "bf8" + +set +x diff --git a/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh new file mode 100755 index 0000000000000000000000000000000000000000..951f8aa63ae5bf9c6aeb97106baa8af644eb8c63 --- /dev/null +++ b/example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh @@ -0,0 +1,36 @@ +#!/bin/bash +EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=2 -warmup=0 -repeat=1' + +run_tests() { + for m in 512 1024; do + for n in 512 2048; do + for k in 512 1024; do + + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly." 
+ # Optionally, exit or break if you need to halt further execution + # exit 1 + fi + + done + done + done +} + +set -x + +run_tests "fp16" +run_tests "bf16" +run_tests "fp8" +run_tests "bf8" + +set +x diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 6c87ca0087e1610977a8c9d06242ba611ca17c44..d1b79177f50cdb50d8018955ceac50c2dd1f7faa 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -9,20 +9,17 @@ #include #include -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -#define CK_TILE_PIPELINE_COMPUTE 1 -#define CK_TILE_PIPELINE_MEMORY 2 - -#ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE -#endif - -template -float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) +template +float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) // Memory friendly for Interwave scheduler @@ -38,10 +35,28 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) + constexpr bool DoubleSmemBuffer = false; +#endif +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) // Compute friendly for Intrawave scheduler constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t M_Warp = 2; @@ -51,13 +66,19 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = true; #endif constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; - constexpr int kBlockPerCu = 1; + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; // =============================================== @@ -65,20 +86,26 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTilePartitioner; - - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; 
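
The GEMM_PIPELINE, UNIVERSAL_GEMM_PIPELINE, and GEMM_PIPELINE_SCHEDULER aliases used below are picked by the CK_TILE_PIPELINE_DEFAULT macro added to gemm_basic.hpp (1 = compute V3 with Intrawave, 2 = memory-friendly with Interwave, 3 = compute V4 with Intrawave), and the header only sets the default when the macro is undefined. A hedged build sketch of overriding it; the direct hipcc invocation and the include-path / GPU-arch placeholders are illustrative assumptions, not the project's actual build rules:

# pick the memory-friendly pipeline (CK_TILE_PIPELINE_MEMORY == 2) instead of the compute V3 default
hipcc -DCK_TILE_PIPELINE_DEFAULT=2 --offload-arch=<gpu-arch> -I<ck_tile-include-dirs> \
    example/ck_tile/03_gemm/universal_gemm.cpp -o tile_example_gemm_universal
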
-#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3< -#endif - ck_tile::GemmPipelineProblem>; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); + using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); @@ -87,36 +114,35 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3< -#endif - ck_tile::UniversalGemmPipelineProblem>; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKargs(args.p_a, - args.p_b, - args.p_c, - args.M, - args.N, - args.K, - args.stride_A, - args.stride_B, - args.stride_C); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) @@ -139,6 +165,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) if(has_hot_loop) { +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" << tail_num + << "\" which is not supported! 
PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { @@ -199,23 +240,26 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ck_tile::integral_constant{}); } } - } - else - { - // Tail number always Full - #PrefetchStages - if(tail_num == ck_tile::TailNumber::Full) +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { - std::ostringstream err; - err << "When there's no hot loop, this tail number \"" << tail_num - << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } +#endif + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." + << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } return ave_time; @@ -223,4 +267,115 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) #include "run_gemm_example.inc" +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "R") + { + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else if(a_layout == "R" && b_layout == "C") + { + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else if(a_layout == "C" && b_layout == "C") + { + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else 
if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else if(a_layout == "C" && b_layout == "R") + { + if(data_type == "fp16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "bf16") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "fp8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(data_type == "bf8") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp index 005541dc62bc428df3011df44234754ee663f9b2..602661f7791a69192a8df3b9a63f5eb48eca50da 100644 --- a/example/ck_tile/05_reduce/reduce.cpp +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -52,7 +52,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // using WarpTile = ck_tile::sequence<1, 512>; // using Vector = ck_tile::sequence<1, 8>; - constexpr ck_tile::index_t kBlockSize = 512; + constexpr ck_tile::index_t kBlockSize = 256; constexpr ck_tile::index_t kBlockPerCu = 1; ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); std::cout << "grid size " << kGridSize << std::endl; diff --git a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt index a3ff8fdf4595715a1620e3a8f46de870ecd06bb1..5684c9b2e00f30248e39e2a0f93d08a9cb2e33bd 100644 --- a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt +++ b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt @@ -1,16 +1,39 @@ +set(RMSNORM2D_FWD_KNOWN_APIS "fwd;bwd") +set(RMSNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${RMSNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".") +if(RMSNORM2D_FWD_ENABLE_APIS STREQUAL "all") + set(RMSNORM2D_FWD_ENABLE_APIS ${RMSNORM2D_FWD_KNOWN_APIS}) +endif() + +# generate a list of kernels, but not actually emit files at config sta +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "Fail to generate kernels via Python. 
${ret}") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/rmsnorm2d_fwd_blobs.txt RMSNORM2D_FWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${RMSNORM2D_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs +) + set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd") -# not using add_example_executable() to add this target, since we don't want this to have -# to be included in "make all/install/check" + message("adding ${TILE_RMSNORM2D_FWD}") -file(GLOB INSTANCE_SRCS instances/*.cpp) add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp) target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${INSTANCE_SRCS}) +target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS}) set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress) target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index 34df7b74fa3c710becc10670b54c76f910317781..48c150009e291b6baab28066d2185fc62487bc94 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -1,6 +1,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/rmsnorm2d.hpp" #include @@ -36,10 +37,12 @@ bool run(const ck_tile::ArgParser& arg_parser) assert(stride >= n); - using XDataType = DataType; - using YDataType = DataType; - using GammaDataType = DataType; - using InvRmsDataType = ck_tile::null_type; + using XDataType = DataType; + using YDataType = DataType; + using GammaDataType = DataType; + using InvRmsDataType = ck_tile::null_type; + using SmoothScaleDataType = ck_tile::null_type; + using YScaleDataType = ck_tile::null_type; using ComputeDataType = float; @@ -68,30 +71,49 @@ bool run(const ck_tile::ArgParser& arg_parser) using BlockTile = ck_tile::sequence<2, 128>; using WarpTile = ck_tile::sequence<1, 64>; using Vector = ck_tile::sequence<1, 1>; + using Shape = ck_tile::Generic2dBlockShape; + + using PipelineTraits = + ck_tile::Rmsnorm2dFwdTraits; // fuse quant - using Shape = ck_tile::Generic2dBlockShape; using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem; + PipelineTraits>; using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; using Pipeline = std::conditional_t; - using Kernel = ck_tile::Rmsnorm2dFwd; + + using Default2DEpilogueProblem = ck_tile:: + Default2DEpilogueProblem; + using Default2DEpilogue = ck_tile::Default2DEpilogue; + + using Kernel = ck_tile::Rmsnorm2dFwd; ck_tile::Rmsnorm2dFwdHostArgs args{x_buf.GetDeviceBuffer(), + nullptr, + nullptr, gamma_buf.GetDeviceBuffer(), y_buf.GetDeviceBuffer(), nullptr, + nullptr, + nullptr, epsilon, m, n, + stride, + stride, + stride, stride}; auto kargs = Kernel::MakeKargs(args); diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py new 
file mode 100644 index 0000000000000000000000000000000000000000..dadb2268b2e50639a603f4f17915e6508eca1186 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -0,0 +1,683 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import sys +from typing import List, Optional, Any +import functools +import itertools +import copy +from dataclasses import dataclass + + +def get_if_str(idx, total, lase_else = True): + if idx == 0: + return 'if' + elif idx < total - 1: + return 'else if' + else: + if lase_else: + return 'else' + else: + return 'else if' + +FUSED_ADD_ENUM_STR_MAP = [ + 'no', + 'pras', # pre-norm + 'pra' ] # post-norm + +FUSED_FUSED_SWEEP_STR_MAP = [ + 'no', + 'sdquant', # smooth dynamic quant + 'dquant' ] # dynamic quant (without sm_scale) + +DATA_TYPE_MAP = {'fp32' : 'float', + 'fp16' : 'ck_tile::fp16_t', + 'bf16' : 'ck_tile::bf16_t', + 'int8' : 'ck_tile::int8_t', + 'fp8' : 'ck_tile::fp8_t'} + +def BOOL_MAP(b_) -> str: + if b_: + return 'true' + else: + return 'false' + + +class rmsnorm_fwd_codegen: + API_TRAITS_DEFINE = """ +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct rmsnorm2d_fwd_traits_ +{ + using XDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using SmoothScaleDataType = ck_tile::remove_cvref_t; + using YScaleDataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kTwoPass = kTwoPass_; + static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; + static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; +}; + +template +using traits_ = rmsnorm2d_fwd_traits_; +""" + + API_COMMON_HEADER = """ +// SPDX-License-Identifier: MIT 
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "rmsnorm2d_fwd.hpp" +#include +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = rmsnorm2d_fwd_args; + +{F_traits_define} + +template +float rmsnorm2d_fwd_(const S& s, A a) +{{ + using XDataType = typename Traits_::XDataType; + using YDataType = typename Traits_::YDataType; + using SmoothScaleDataType = typename Traits_::SmoothScaleDataType; + using YScaleDataType = typename Traits_::YScaleDataType; + using ComputeDataType = typename RmsnormTypeConfig::ComputeDataType; + + using PipelineTraits = + ck_tile::Rmsnorm2dFwdTraits(Traits_::kFusedAdd), + static_cast(Traits_::kFusedQuant)>; + + using PipelineProblem = + ck_tile::Rmsnorm2dFwdPipelineProblem::XDataType, + typename RmsnormTypeConfig::GammaDataType, + typename RmsnormTypeConfig::ComputeDataType, + typename RmsnormTypeConfig::YDataType, + typename RmsnormTypeConfig::InvRmsDataType, + typename RmsnormTypeConfig::SmoothScaleDataType, + typename RmsnormTypeConfig::YScaleDataType, + typename Traits_::Shape, + PipelineTraits>; + + using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; + using Default2DEpilogue = ck_tile::Default2DEpilogue; + + static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1; + using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; + + using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; + + using Epilogue = std::conditional_t; + + using Kernel = ck_tile::Rmsnorm2dFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); +}} + +""" + + API_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "rmsnorm2d_fwd.hpp" + +{F_traits_define} + +// Note: this internal API only declare, not define here, otherwise will block `make -j` +template +float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a); + +float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, + rmsnorm2d_fwd_args a, + const ck_tile::stream_config& s) +{{ + float r = -1; +{F_dispatch} + return r; +}} + +""" + + INSTANCE_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_api_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +{F_instance_def} +// clang-format on + +""" + + API_PER_DTYPE = """ + {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{ +{F_per_n_case} + }} +""" + API_PER_N_CASE = """ + {F_if} {F_N_COND} {{ +{F_inner_dispatch} + }} +""" + API_INNER_CASE = """ + {F_if} {F_VEC_COND} + r={F_instance_func}(s, a); +""" + + def __init__(self, working_path, kernel_filter): + self.working_path = working_path + self.kernel_filter = kernel_filter + + class k_fuesd_add_enum(IntEnum): + F_NO_ADD = 0 + F_PRE_ADD = 1 + F_PRE_ADD_STORE_RESIDUAL = 2 + + class k_fused_sweep_enum(IntEnum): + F_NO_SWEEP = 0 + F_RENORM = 1 + F_DYNAMIC_QUANT = 2 + + @dataclass + class k_traits: + F_kPadN : bool + F_kSaveMeanInvStd : bool + F_kTwoPass : bool + F_kFusedAdd : Any + F_kFusedQuant : Any + + @dataclass + class k_shape: + F_BlockTile : List[int] + F_WarpPerBlock : List[int] + F_WarpTile : List[int] + F_Vector_ : List[int] + @property + def F_BlockSize(self) -> int: + return functools.reduce(lambda a, b: a*b, self.F_WarpTile) + + @dataclass + class k_problem: + F_XDataType : str + F_GammaDataType : str + F_ComputeDataType : str + F_YDataType : str + F_InvRmsDataType : str + F_BlockShape : str + F_Traits : Any #k_traits + + @dataclass + class k_pipeline_one_pass: + F_Problem : Any #k_problem + + @dataclass + class k_pipeline_two_pass: + F_Problem : Any #k_problem + + @dataclass + class default_2d_epilogue_problem: + F_AccDataType : str + F_ODataType : str + F_kPadM : bool + F_kPadN : bool + + @dataclass + class default_2d_epilogue: + F_problem : Any + + @dataclass + class k_kernel: + F_pipeline : Any + F_epilogue : Any + + @dataclass + class h_traits: + F_XDataType : str + F_YDataType : str + F_SmoothScaleDataType : str + F_YScaleDataType : str + F_Repeat_M : int + F_Repeat_N : int + F_ThreadPerBlock_M : int + F_ThreadPerBlock_N : int + F_Vector_N : int + F_kPadN : bool + F_kSaveInvRms : bool + F_kTwoPass : bool + F_kFusedAdd : int + F_kFusedQuant : int + + @property + def trait_name(self) ->str: + t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}' + t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' + return t_ + + # string when calling this kernel + @property + def call_name(self) -> str: + return f'rmsnorm2d_fwd_>' + + # string when define this kernel + @property + def def_name(self) -> str: + return f'template float rmsnorm2d_fwd_>(const S&, A);' + + # this class hold kernel under same source file + @dataclass + class h_instance: + F_DataTypePair : str + F_N : str + F_add : int + F_sweep : int + instance_list : List[Any] # List[h_traits] + + @property + def name(self) -> str: + prec_i, prec_o = self.F_DataTypePair.split(',') + dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' + nnn = f'rmsnorm2d_fwd_{dtype_str}_n{self.F_N}' + if self.F_add != 0: + nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] + if self.F_sweep != 0: + nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] + return nnn + + @property + def instance_name(self) ->str: + return self.name + + @property + def content(self) ->str: + instance_defs = '' + for ins in self.instance_list: + 
instance_defs += ins.def_name + '\n' + return rmsnorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs) + + @property + def name_api(self) -> str: + return 'rmsnorm2d_fwd_api' + + @property + def name_common_header(self) -> str: + return 'rmsnorm2d_fwd_api_common' + + @property + def content_api(self) -> str: + # 1 sort based on dtype + t_dtype_dict = dict() + blobs = self.get_blobs() + for blob in blobs: + if blob.F_DataTypePair not in t_dtype_dict: + t_dtype_dict[blob.F_DataTypePair] = {} + if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]: + t_dtype_dict[blob.F_DataTypePair][blob.F_N] = [] + t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob) + + d_str = '' + for i_d, dtype_ in enumerate(t_dtype_dict): + blob_per_t = t_dtype_dict[dtype_] + n_str = '' + for i_n, n_ in enumerate(blob_per_t): + blob_per_n = blob_per_t[n_] + inner_str = "" + for i_b, b_ in enumerate(blob_per_n): + # generate single kernel instance file + #vec_str = "" + for i_ins, ins in enumerate(b_.instance_list): + idx_in_n = i_b * len(b_.instance_list) + i_ins + len_in_n = len(blob_per_n) * len(b_.instance_list) + # _if = 'if' if i_ins == 0 else 'else if' + if ins.F_kFusedQuant == 0: + _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) + elif ins.F_kFusedQuant == 1: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType) + elif ins.F_kFusedQuant == 2: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) + _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( + f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, + f_sweep_cond = _sweep_cond) + inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), + F_VEC_COND = _cond, F_instance_func=ins.call_name) + #inner_str = inner_str + vec_str + n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else '' + n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) + prec_i, prec_o = dtype_.split(',') + d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) + + api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str) + return api_base + + @property + def content_common_header(self) -> str: + return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE) + + def get_blobs(self): + h_traits = rmsnorm_fwd_codegen.h_traits + h_instance = rmsnorm_fwd_codegen.h_instance + + dynamic_quant_out_dtype = ['int8', 'fp8'] + # some predefined support range + # (prec_i,prec_o) for simplicity this string will be used as key for dict + scale_list = [('fp32,fp32')] + dtype_list = [('fp16,fp16'), ('bf16,bf16'), + ('fp16,int8'), ('bf16,int8'), + ('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 out + #fused_add_list = [0, 1, 2] + #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant + fused_add_list = [0, 1] + fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant + + # rm rn tm tn vn pd mv 2p add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, False, 0, 
0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 
4, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} + total_blob = list() + for hs_key in h_trait_dict: + hs = h_trait_dict[hs_key] + current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N + for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list): + prec_i, prec_o = dtype.split(',') + scale_sm, scale_y = scale_type.split(',') + if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2: + continue # skip non dynamic quant case + if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big': + continue + current_hs = list() + for chs_ in hs: + h_ = copy.copy(chs_) # copy the base instance out + h_.F_XDataType = prec_i + h_.F_YDataType = prec_o + h_.F_SmoothScaleDataType = scale_sm + h_.F_YScaleDataType = scale_y + h_.F_kFusedAdd = fused_add + h_.F_kFusedQuant = fused_quant + current_hs.append(h_) # + "\n" + #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ + current_n_str = 'big' if hs_key == 'big' else current_n + total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) + return total_blob + + def list_blobs(self) -> None: + w_p = Path(self.working_path) + list_p = w_p / 'rmsnorm2d_fwd_blobs.txt' + blobs = self.get_blobs() + with list_p.open('w') as list_f: + # api related file + list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") + list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") + # kernel instance file + for b in blobs: + list_f.write(str(w_p / (b.name + ".cpp")) + "\n") + + def gen_blobs(self) -> None: + w_p = Path(self.working_path) + (w_p / (self.name_api + ".cpp")).write_text(self.content_api) + (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) + blobs = self.get_blobs() + for b in blobs: + (w_p / (b.name + ".cpp")).write_text(b.content) + + +def list_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + rmsnorm_fwd_codegen(args.working_path, args.filter).list_blobs() + + +def gen_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + rmsnorm_fwd_codegen(args.working_path, args.filter).gen_blobs() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK rmsnorm kernel", + ) + parser.add_argument( + "-a", + "--api", + default='fwd[all]', + required=False, + help="supply API(s) to generate (default: fwd). separated by comma." + ) + + # the directory for list_blobs/gen_blobs to write files into + parser.add_argument( + "-w", + "--working_path", + default="./", + required=False, + help="the path where all the blobs are going to be generated" + ) + + # this script have 2 modes + # 1) list_blobs mode, will generate a txt file with all the files going to be generated. + # this is useful in build system like cmake to construct source code dependency, by + # reading the content out of this file + # 2) gen_blobs mode, will generate the actuall kernel instance and api. 
If in framework + # like FA, only need to use this mode + parser.add_argument( + "-l", + "--list_blobs", + action='store_true', + help="list all the kernels to a file, " + ) + + parser.add_argument( + "-g", + "--gen_blobs", + action='store_true', + help="generate all kernels into different tile" + ) + + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) + + parser.add_argument( + "-t", + "--traits", + default="all", + required=False, + help="enable/disable some feature. default generate all" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt." + ) + + args = parser.parse_args() + + # print(f'{args.list_blobs}-{args.gen_blobs}') + if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)): + print('gen_blobs/list_blobs must specify only one option') + sys.exit() + + p = Path(args.working_path) + if not p.exists(): + p.mkdir() + + if args.list_blobs: + list_blobs(args) + else: + gen_blobs(args) \ No newline at end of file diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp deleted file mode 100644 index b8697183f96bc6b4421ecbd8b26353ae7c00941e..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp +++ /dev/null @@ -1,146 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include "rmsnorm2d_fwd.hpp" - -template -using trait_ = rmsnorm2d_fwd_traits_; - -template -float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/, - rmsnorm2d_fwd_args a, - const ck_tile::stream_config& s) -{ - float r = -1; - // clang-format off - // rm rn tm tn vn pd rms 2p - if(a.n <= 64) { - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 128) { - if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 256) { - if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 512) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 768) { - if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 1024) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 1536) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 2048) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n <= 3072) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else 
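Note on the tile buckets defined above: each dictionary key is the largest hidden size N that bucket serves, and for every h_traits entry the product Repeat_N * ThreadPerBlock_N * Vector_N equals that key (for example, in the '1024' bucket 1 * 128 * 8 = 1024 and 2 * 128 * 4 = 1024); the 'big' bucket appears to keep a 4096-wide tile and switch on the two-pass flag instead of growing further. A minimal sketch of driving the generator's two modes, assuming the codegen script above is invoked as generate.py (its file name sits outside this hunk, so treat the name as illustrative):

# list mode: only writes rmsnorm2d_fwd_blobs.txt so a build system can learn the output file set
python3 generate.py -a fwd -l -w generated
# gen mode: writes the API .cpp, the common header .hpp, and one .cpp per kernel instance
python3 generate.py -a fwd -g -w generated

Exactly one of -l/-g must be given; passing both, or neither, makes the script print an error and exit.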
if(a.n <= 4096) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - else if(a.n > 4096) { - if (a.n % 8 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 4 == 0) - r = rmsnorm2d_fwd_>(s, a); - else if (a.n % 2 == 0) - r = rmsnorm2d_fwd_>(s, a); - else - r = rmsnorm2d_fwd_>(s, a); - } - return r; - // clang-format on -} - -float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s) -{ - - if(t.data_type.compare("fp16") == 0) - { - return rmsnorm2d_fwd_b16_(t, a, s); - } - else if(t.data_type.compare("bf16") == 0) - { - return rmsnorm2d_fwd_b16_(t, a, s); - } - else - throw std::runtime_error("Without supported instances!"); -} diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp deleted file mode 100644 index 5e2a35f9e8fb496f21832c222b9f68042a63a21d..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp +++ /dev/null @@ -1,22 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -#if 0 -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -template float rmsnorm2d_fwd_>(const S&, A); -#endif - -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp deleted file mode 100644 index 8c734806e18b4782092f7a0e5cc460b3abc158d4..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp deleted file mode 100644 index 9222001433464eebcf1e20911b6b06b85c117270..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp deleted file mode 100644 index ed33c849232cc95d251240f7d146678273bd4e52..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp deleted file mode 100644 index b753bbc3458d3194f0cc6962b51d499bd331848b..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp deleted file mode 100644 index 27cb9bdf3d47dc34909e0c1333c5daa32e640274..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp deleted file mode 100644 index 23afb5672b4b109fa9d2b89abec5318766540e92..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp deleted file mode 100644 index b428f58051bae64bd0497a991a7fc265ad96ec3d..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp deleted file mode 100644 index 3001106697dafa0d521672af936e17b1cf2fddac..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp deleted file mode 100644 index e9c8d6a1d444b0b085bbfc20575352a1d766b2a3..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp deleted file mode 100644 index 15198eebe67258266529ed81a7c8f5bf16d48ca2..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp +++ /dev/null @@ -1,22 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -#if 0 -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -template float rmsnorm2d_fwd_>(const S&, A); -#endif - -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp deleted file mode 100644 index 8ac85fa9b5a68246a5af7c039b4131a3b35c9c56..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp deleted file mode 100644 index 10e8fafc2f4c780ff622ee501f685671fc7dd25a..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp deleted file mode 100644 index 4e1a80bf64b3598864765454fd46f8ce9c9c6eb0..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp deleted file mode 100644 index 45e56a92b8886ffe3b07189646aad12caeffb359..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp deleted file mode 100644 index 35401f6f82b50c40137599456250c13c46092cb2..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp deleted file mode 100644 index 1e3700fad3ab61ff669d2950f96578170a01320e..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp deleted file mode 100644 index cdc4d00bd2336a1e55c900cd078dd8cde52ac11b..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp +++ /dev/null @@ -1,13 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp deleted file mode 100644 index ec80c2ee4a93f999be3960ba7154a16d8992f302..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp deleted file mode 100644 index ddfc5a54e8e6ea3129804f28a56aed98b5432f67..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp +++ /dev/null @@ -1,12 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "rmsnorm2d_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd rms 2p -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -template float rmsnorm2d_fwd_>(const S&, A); -// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp deleted file mode 100644 index 8f6ff84b643d2b7fafebc5b0a9ef6ade1ebdbd23..0000000000000000000000000000000000000000 --- a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp +++ /dev/null @@ -1,65 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include "rmsnorm2d_fwd.hpp" -#include - -#pragma once - -using S = ck_tile::stream_config; -using A = rmsnorm2d_fwd_args; - -template -using trait_ = rmsnorm2d_fwd_traits_; - -template -float rmsnorm2d_fwd_(const S& s, A a) -{ - using DataType = typename Traits_::DataType; - - using PipelineProblem = - ck_tile::Rmsnorm2dFwdPipelineProblem::XDataType, - typename RmsnormTypeConfig::GammaDataType, - typename RmsnormTypeConfig::ComputeDataType, - typename RmsnormTypeConfig::YDataType, - typename RmsnormTypeConfig::InvRmsDataType, - typename Traits_::Shape, - Traits_::kPadN, - Traits_::kSaveInvRms, - Traits_::kTwoPass>; - - using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; - using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; - using Pipeline = std::conditional_t; - - using Kernel = ck_tile::Rmsnorm2dFwd; - - const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = 1; - - auto kargs = Kernel::MakeKargs(a); - if(s.log_level_ > 0) - std::cout << ", " << Kernel::GetName() << std::flush; - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); -} diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 698a8b43eb9329f5bfb0c61b78cea98a0cf07f5f..cdee6dfb80041ef3afa8e5545313b1bab472399b 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -19,17 +19,37 @@ auto get_elimit() return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit() +{ + double rtol = 1e-02; + double atol = 1.0; + return ck_tile::make_tuple(rtol, atol); +} + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3328", "m dimension") .insert("n", "4096", "n dimension") - .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("x_stride", "-1", "x row_stride, if -1 then equal to n") + 
.insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n") + .insert("y_stride", "-1", "y row_stride, if -1 then equal to n") + .insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n") .insert("e", "1e-5", "epsilon") .insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case") .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") - .insert("prec", "fp16", "precision") + .insert("prec_i", "fp16", "input precision") + .insert("prec_o", "auto", "output precision, set auto will be the same as input") + .insert("prec_sm", + "auto", + "output quant scale type, set auto will use fp32. used when fquant=1") + .insert("prec_sy", + "auto", + "output quant scale type, set auto will use fp32. used when fquant=1 or 2") + .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") + .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("warmup", "5", "cold iter") .insert("repeat", "20", "hot iter"); @@ -37,28 +57,70 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - ck_tile::index_t stride = arg_parser.get_int("stride"); - if(stride < 0) - stride = n; - float epsilon = arg_parser.get_float("e"); - std::string data_type = arg_parser.get_str("prec"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - - assert(stride >= n); - - using TypeConfig = RmsnormTypeConfig; - - using XDataType = typename TypeConfig::XDataType; - using YDataType = typename TypeConfig::YDataType; - using GammaDataType = typename TypeConfig::GammaDataType; + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + float epsilon = arg_parser.get_float("e"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); + if(x_stride < 0) + x_stride = n; + ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride"); + if(xr_stride < 0) + xr_stride = n; + ck_tile::index_t y_stride = arg_parser.get_int("y_stride"); + if(y_stride < 0) + y_stride = n; + ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride"); + if(yr_stride < 0) + yr_stride = n; + assert(x_stride >= n); + + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sm = arg_parser.get_str("prec_sm"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + if(prec_o == "auto") + { + prec_o = prec_i; + } + if(prec_sm == "auto") + { + prec_sm = "fp32"; + } + if(prec_sy == "auto") + { + prec_sy = "fp32"; + } + + if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8" && prec_o != "fp8") + { + std::cout + << "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases." 
+ << std::endl; + return false; + } + + using TypeConfig = + RmsnormTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using YDataType = typename TypeConfig::YDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; using InvRmsDataType = std::conditional_t; @@ -66,43 +128,84 @@ bool run(const ck_tile::ArgParser& arg_parser) using ComputeDataType = typename TypeConfig::ComputeDataType; // host verify - ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor gamma_host({n}); + ck_tile::HostTensor sm_scale_host({n}); + ck_tile::HostTensor sm_scale_host_dev({n}); - ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); - ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor x_residual_host({m, n}, {xr_stride, 1}); + ck_tile::HostTensor y_residual_host({m, n}, {yr_stride, 1}); + + ck_tile::HostTensor y_host_ref({m, n}, {y_stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {y_stride, 1}); + ck_tile::HostTensor y_scale_host_ref({m}); + ck_tile::HostTensor y_scale_host_dev({m}); ck_tile::HostTensor invRms_host_ref({m}); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_residual_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(sm_scale_host); ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); gamma_buf.ToDevice(gamma_host.data()); + x_residual_buf.ToDevice(x_residual_host.data()); + sm_scale_buf.ToDevice(sm_scale_host.data()); + + auto prec_str = [&]() { + auto base_str = prec_i; + if(prec_i != prec_o) + { + base_str += "|" + prec_o; + } + if(fused_quant == 1) + { + base_str += std::string("(") + prec_sy + ")"; + } + return base_str; + }(); - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + std::cout << "[" << prec_str << "]" + << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride + << ", yr_stride:" << yr_stride << std::flush; - rmsnorm2d_fwd_traits traits{data_type, SaveRms}; + rmsnorm2d_fwd_traits traits{prec_i, prec_o, prec_sm, prec_sy, SaveRms, fused_add, fused_quant}; rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr, gamma_buf.GetDeviceBuffer(), y_buf.GetDeviceBuffer(), - nullptr, + fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr, + nullptr, // p_invRms, unsupported yet epsilon, m, n, - stride}; + x_stride, // x row_stride + xr_stride, // x residule row stride + y_stride, // y row stride + yr_stride}; // y residule row stride float ave_time = rmsnorm2d_fwd( traits, args, ck_tile::stream_config{nullptr, true, kname ? 
1 : 0, warmup, repeat}); std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n; + num_byte += SaveRms ? sizeof(InvRmsDataType) * m * n : 0; + num_byte += fused_add ? sizeof(XResidualDataType) * m * n : 0; + num_byte += ((fused_quant == 1) || (fused_quant == 2)) ? sizeof(YScaleDataType) * m : 0; + num_byte += (fused_quant == 1) ? sizeof(SmoothScaleDataType) * n : 0; float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; @@ -112,38 +215,135 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { // reference - ck_tile::reference_rmsnorm2d_fwd( - x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + if(fused_add != 0) + { + // fused pre_add/pre_add_store + // TODO we accumulate directly to x_host for simplcity here... + std::transform(x_host.mData.cbegin(), + x_host.mData.cend(), + x_residual_host.mData.cbegin(), + x_host.mData.begin(), + [](auto x_, auto r_) { + auto o_ = ck_tile::type_convert(x_) + + ck_tile::type_convert(r_); + return ck_tile::type_convert(o_); + }); + } + + if(fused_quant != 0) + { + auto dquant_functor = [&](int m_, auto& o_, auto& acc_) { + int N_ = acc_.mDesc.get_lengths()[1]; + if(fused_quant == 1) + { + for(int n_ = 0; n_ < N_; n_++) + { + // input smooth outlier + acc_(m_, n_) = acc_(m_, n_) * + ck_tile::type_convert(sm_scale_host(n_)); + } + } + ComputeDataType absmax = static_cast(0); + for(int n_ = 0; n_ < N_; n_++) + { + const auto a = ck_tile::abs(acc_(m_, n_)); + absmax = a > absmax ? a : absmax; + } + // printf("cpu:absmax:%f\n", absmax); + constexpr ComputeDataType kMaxY = + std::is_same::value ? 240.0 + : std::is_same::value ? 127.0 + : 0.0; + ComputeDataType y_scale = absmax / kMaxY; + y_scale_host_ref(m_) = ck_tile::type_convert(y_scale); + for(int n_ = 0; n_ < N_; n_++) + { + o_(m_, n_) = ck_tile::type_convert(acc_(m_, n_) / y_scale); + } + }; + + ck_tile::reference_rmsnorm2d_fwd( + x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon, dquant_functor); + } + else + { + ck_tile::reference_rmsnorm2d_fwd( + x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + } y_buf.FromDevice(y_host_dev.data()); - auto [rtol, atol] = get_elimit(); - if(stride == n) + ck_tile::HostTensor y_residual_host_dev({m, n}, {yr_stride, 1}); + if(fused_add == 1) + { + y_residual_buf.FromDevice(y_residual_host_dev.data()); + } + + auto [rtol, atol] = get_elimit(); + if(x_stride == n) { pass = ck_tile::check_err( - y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + y_host_dev, y_host_ref, std::string("\nOUT Error: Incorrect results!"), rtol, atol); + + if(fused_add == 1) + { + pass &= ck_tile::check_err(y_residual_host_dev, + x_host, + std::string("\nADD Error: Incorrect results!"), + rtol, + atol); + } } else { for(int i_r = 0; i_r < m; i_r++) { - std::vector y_host_dev_row(y_host_dev.begin() + i_r * stride, - y_host_dev.begin() + i_r * stride + n); - std::vector y_host_ref_row(y_host_ref.begin() + i_r * stride, - y_host_ref.begin() + i_r * stride + n); + std::vector y_host_dev_row(y_host_dev.begin() + i_r * y_stride, + y_host_dev.begin() + i_r * y_stride + n); + std::vector y_host_ref_row(y_host_ref.begin() + i_r * y_stride, + y_host_ref.begin() + i_r * y_stride + n); pass &= ck_tile::check_err(y_host_dev_row, y_host_ref_row, - std::string("OUT[") + std::to_string(i_r) + + std::string("\nOUT[") + std::to_string(i_r) + std::string("] Error: Incorrect results!"), rtol, atol); + + 
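For reference, the dquant_functor above spells out what the fused dynamic-quant validation expects per row m: with fquant=1 the accumulator row is first scaled column-wise by sm_scale, then y_scale(m) = max_n |acc(m, n)| / kMaxY, where kMaxY is 127 for int8 output and 240 for fp8 output, and the stored value is acc(m, n) / y_scale(m) converted to the output type; the per-row y_scale itself is compared against y_scale_host_ref further below when fquant=1.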
if(fused_add == 1) + { + std::vector y_residual_host_dev_row( + y_residual_host_dev.begin() + i_r * yr_stride, + y_residual_host_dev.begin() + i_r * yr_stride + n); + std::vector y_residual_host_ref_row( + x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n); + pass &= ck_tile::check_err(y_residual_host_dev_row, + y_residual_host_ref_row, + std::string("\nADD[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } } } + if(fused_quant == 1) + { + y_scale_buf.FromDevice(y_scale_host_dev.data()); + pass &= ck_tile::check_err(y_scale_host_dev, + y_scale_host_ref, + std::string("\nSCALE Error: Incorrect results!"), + rtol, + atol); + } + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } @@ -156,23 +356,65 @@ int main(int argc, char* argv[]) if(!result) return -1; - const std::string data_type = arg_parser.get_str("prec"); - int save_rms = arg_parser.get_int("save_rms"); - if(data_type == "fp16" && save_rms) + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sm = arg_parser.get_str("prec_sm"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + if(prec_o == "auto") + { + prec_o = prec_i; + } + if(prec_sm == "auto") + { + prec_sm = "fp32"; + } + if(prec_sy == "auto") + { + prec_sy = "fp32"; + } + + int save_rms = arg_parser.get_int("save_rms"); + + if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && + save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) + { + return run(arg_parser) ? 0 : -2; + } + + // dynamic quant case, only in inference + else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "fp16" && !save_rms) + else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "bf16" && save_rms) + else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "bf16" && !save_rms) + else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } return -3; diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp index b4d429d46f4a1a418527bf515ccd5e06e6243352..566b94442d4c934cc63375e18a5a93eb6b55841c 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -8,27 +8,34 @@ #include "ck_tile/ops/rmsnorm2d.hpp" #include -template +template struct RmsnormTypeConfig; -template <> -struct RmsnormTypeConfig +template +struct RmsnormTypeConfig { - using XDataType = ck_tile::half_t; - using YDataType = ck_tile::half_t; - using GammaDataType = ck_tile::half_t; - using InvRmsDataType = ck_tile::half_t; - using ComputeDataType = float; + using XDataType = ck_tile::half_t; + using YDataType = OutType; + using GammaDataType = ck_tile::half_t; + using InvRmsDataType = ck_tile::half_t; + using ComputeDataType = float; + using SmoothScaleDataType = SmoothScaleDataType_; + using YScaleDataType = YScaleDataType_; }; -template <> -struct RmsnormTypeConfig +template +struct RmsnormTypeConfig { - using XDataType = ck_tile::bf16_t; - using YDataType = ck_tile::bf16_t; - using GammaDataType = ck_tile::bf16_t; - using InvRmsDataType = ck_tile::bf16_t; - using ComputeDataType = float; + using XDataType = ck_tile::bf16_t; + using YDataType = OutType; + using GammaDataType = ck_tile::bf16_t; + using InvRmsDataType = ck_tile::bf16_t; + using ComputeDataType = float; + using SmoothScaleDataType = SmoothScaleDataType_; + using YScaleDataType = YScaleDataType_; }; // runtime args @@ -36,82 +43,24 @@ struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs { }; -// this is used to pattern-match internl kernel implementation, not to instantiate kernel -template -struct rmsnorm2d_fwd_traits_ -{ - using DataType = ck_tile::remove_cvref_t; - - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); - static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; - - // num of warps along m - static constexpr ck_tile::index_t BlockWarps_M = []() { - if constexpr(is_warp_per_row) - { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); - } - else - { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); - } - }(); - - // num of warps along n - static constexpr ck_tile::index_t BlockWarps_N = []() { - if constexpr(is_warp_per_row) - { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return 1; - } - else - { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; - } - }(); - - static constexpr ck_tile::index_t Repeat_M = Repeat_M_; - static constexpr ck_tile::index_t Repeat_N = Repeat_N_; - - static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; - static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; - - static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; - static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; - - using BlockTile = ck_tile::sequence; - using BlockWarps = ck_tile::sequence; - using WarpTile = ck_tile::sequence; - using Vector = ck_tile::sequence<1, Vector_N_>; - - using Shape = ck_tile::Generic2dBlockShape; - - static constexpr bool kPadN = kPadN_; - static constexpr bool kSaveInvRms = kSaveInvRms_; - static constexpr bool kTwoPass = kTwoPass_; -}; - template float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a); // This is the public API, will be generated by script struct rmsnorm2d_fwd_traits { - std::string data_type; + std::string prec_i; // input precision + std::string prec_o; // output precision + + // if fused_quant == 1, need 
set prec_sm/prec_sy to proper string, otherwise can set + // arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise + // can set arbitrary(will skip check) + std::string prec_sm; // x-scale, used for [1*N] input smooth quant + std::string prec_sy; // y-scale, used for [M*1] output for next layer + bool save_rms; + int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add + int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant }; float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh index 758d6de54680cc303068044c5e0fc8d27baba4b2..ab890738b31aff69169d34d6af3214d7f2b63dbc 100755 --- a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh +++ b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh @@ -1,30 +1,34 @@ #!/bin/sh EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" +for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"; do for pr_i in "fp16" "bf16" ; do -$EXE -prec=$pr_i -m=99 -n=13 -$EXE -prec=$pr_i -m=17 -n=16 -$EXE -prec=$pr_i -m=1 -n=100 -$EXE -prec=$pr_i -m=4 -n=128 -$EXE -prec=$pr_i -m=80 -n=127 -$EXE -prec=$pr_i -m=22 -n=255 -stride=256 -$EXE -prec=$pr_i -m=7 -n=599 -$EXE -prec=$pr_i -m=19 -n=512 -$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 -$EXE -prec=$pr_i -m=11 -n=510 -$EXE -prec=$pr_i -m=171 -n=676 -stride=818 -$EXE -prec=$pr_i -m=91 -n=636 -$EXE -prec=$pr_i -m=12 -n=768 -stride=800 -$EXE -prec=$pr_i -m=100 -n=766 -stride=812 -$EXE -prec=$pr_i -m=31 -n=1024 -$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 -$EXE -prec=$pr_i -m=8 -n=1501 -$EXE -prec=$pr_i -m=3 -n=1826 -$EXE -prec=$pr_i -m=5 -n=2040 -$EXE -prec=$pr_i -m=7 -n=2734 -$EXE -prec=$pr_i -m=1 -n=3182 -$EXE -prec=$pr_i -m=9 -n=4096 -$EXE -prec=$pr_i -m=3 -n=8192 -$EXE -prec=$pr_i -m=1 -n=10547 -$EXE -prec=$pr_i -m=3 -n=17134 +for fadd in "0" "1"; do +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 +#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 +done +done done diff --git a/example/ck_tile/12_smoothquant/example_smoothquant.cpp 
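The smoke-test loop above sweeps the fused-add and fused-quant variants over fp16/bf16 inputs; note that the rmsnorm2d driver now takes -x_stride/-xr_stride/-y_stride/-yr_stride rather than a single -stride. A single fused invocation, using the $EXE variable the script already defines and illustrative m/n values, looks roughly like:

$EXE -prec_i=bf16 -prec_o=int8 -fadd=1 -fquant=1 -m=22 -n=255 -x_stride=256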
b/example/ck_tile/12_smoothquant/example_smoothquant.cpp index aa1d1adfd1e8d4986f7b5cc269d6bdccb1347567..20e1591516f4b614f023d2c53ac6f5fd8c3d5110 100644 --- a/example/ck_tile/12_smoothquant/example_smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/example_smoothquant.cpp @@ -63,17 +63,17 @@ bool run(const ck_tile::ArgParser& arg_parser) int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); - assert(stride >= n); + assert(x_stride >= n); - using XDataType = DataType; - using XScaleDataType = float; - using YScaleDataType = float; - using QYDataType = ck_tile::int8_t; - using ComputeDataType = float; + using XDataType = DataType; + using SmoothScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; // host verify ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); - ck_tile::HostTensor xscale_host({n}); + ck_tile::HostTensor smscale_host({n}); ck_tile::HostTensor yscale_host_ref({m}, {1}); ck_tile::HostTensor yscale_host_dev({m}, {1}); @@ -82,15 +82,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor qy_host_dev({m, n}, {y_stride, 1}); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); - ck_tile::FillUniformDistribution{1e-3, .5f}(xscale_host); + ck_tile::FillUniformDistribution{1e-3, .5f}(smscale_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); - xscale_buf.ToDevice(xscale_host.data()); + smscale_buf.ToDevice(smscale_host.data()); constexpr bool kTwoPass = true; @@ -101,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser) using Shape = ck_tile::Generic2dBlockShape; using Problem = ck_tile::SmoothquantPipelineProblem; ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(), - xscale_buf.GetDeviceBuffer(), + smscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(), m, @@ -142,16 +142,16 @@ bool run(const ck_tile::ArgParser& arg_parser) // smooth outlier { auto f = [&](auto n_) { - auto v_xscale = ck_tile::type_convert(xscale_host(n_)); + auto v_smscale = ck_tile::type_convert(smscale_host(n_)); for(int m_ = 0; m_ < m; ++m_) { auto v_x = ck_tile::type_convert(x_host(m_, n_)); - y_host(m_, n_) = v_x * v_xscale; + y_host(m_, n_) = v_x * v_smscale; } }; - ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( + ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())( std::thread::hardware_concurrency()); } diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp index cdf93f6fcfd3d523c7723e28e006d089bf49f350..555159566eed1fb2e5b341bb9941aeba8cc9a1a4 100644 --- a/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include "smoothquant.hpp" @@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a) using PipelineProblem = ck_tile::SmoothquantPipelineProblem< typename SmoothquantTypeConfig::XDataType, - typename SmoothquantTypeConfig::XScaleDataType, + typename SmoothquantTypeConfig::SmoothScaleDataType, typename SmoothquantTypeConfig::ComputeDataType, typename SmoothquantTypeConfig::YScaleDataType, typename SmoothquantTypeConfig::QYDataType, diff --git a/example/ck_tile/12_smoothquant/smoothquant.cpp b/example/ck_tile/12_smoothquant/smoothquant.cpp index fd1c4ec7b472564cf53e8f76ce92211362108741..f3ba587132fe4d8b7cb3d08f154c237b2ed53206 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/smoothquant.cpp @@ -66,15 +66,15 @@ bool run(const ck_tile::ArgParser& arg_parser) using TypeConfig = SmoothquantTypeConfig; - using XDataType = typename TypeConfig::XDataType; - using XScaleDataType = typename TypeConfig::XScaleDataType; - using YScaleDataType = typename TypeConfig::YScaleDataType; - using QYDataType = typename TypeConfig::QYDataType; - using ComputeDataType = typename TypeConfig::ComputeDataType; + using XDataType = typename TypeConfig::XDataType; + using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = typename TypeConfig::ComputeDataType; // host verify ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); - ck_tile::HostTensor xscale_host({n}); + ck_tile::HostTensor smscale_host({n}); ck_tile::HostTensor yscale_host_ref({m}, {1}); ck_tile::HostTensor yscale_host_dev({m}, {1}); @@ -83,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor qy_host_dev({m, n}, {y_stride, 1}); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); - ck_tile::FillUniformDistribution{1e-3, .5f}(xscale_host); + ck_tile::FillUniformDistribution{1e-3, .5f}(smscale_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); - xscale_buf.ToDevice(xscale_host.data()); + smscale_buf.ToDevice(smscale_host.data()); std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride @@ -100,7 +100,7 @@ bool run(const ck_tile::ArgParser& arg_parser) smoothquant_traits traits{data_type}; smoothquant_args args{x_buf.GetDeviceBuffer(), - xscale_buf.GetDeviceBuffer(), + smscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(), m, @@ -111,7 +111,7 @@ bool run(const ck_tile::ArgParser& arg_parser) float ave_time = smoothquant( traits, args, ck_tile::stream_config{nullptr, true, kname ? 
1 : 0, warmup, repeat}); - std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XScaleDataType) * n + + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n + sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n; float gb_per_sec = num_byte / 1.E6 / ave_time; @@ -126,16 +126,16 @@ bool run(const ck_tile::ArgParser& arg_parser) // smooth outlier { auto f = [&](auto n_) { - auto v_xscale = ck_tile::type_convert(xscale_host(n_)); + auto v_smscale = ck_tile::type_convert(smscale_host(n_)); for(int m_ = 0; m_ < m; ++m_) { auto v_x = ck_tile::type_convert(x_host(m_, n_)); - y_host(m_, n_) = v_x * v_xscale; + y_host(m_, n_) = v_x * v_smscale; } }; - ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( + ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())( std::thread::hardware_concurrency()); } diff --git a/example/ck_tile/12_smoothquant/smoothquant.hpp b/example/ck_tile/12_smoothquant/smoothquant.hpp index 26a598db55bc19c5ce9e1035eeef2add79fd1f35..83ad7b012ca4fa40c75bd2ef9786eeff5a8bdf6d 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.hpp +++ b/example/ck_tile/12_smoothquant/smoothquant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,21 +14,21 @@ struct SmoothquantTypeConfig; template <> struct SmoothquantTypeConfig { - using XDataType = ck_tile::half_t; - using XScaleDataType = float; - using YScaleDataType = float; - using QYDataType = ck_tile::int8_t; - using ComputeDataType = float; + using XDataType = ck_tile::half_t; + using SmoothScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; }; template <> struct SmoothquantTypeConfig { - using XDataType = ck_tile::bf16_t; - using XScaleDataType = float; - using YScaleDataType = float; - using QYDataType = ck_tile::int8_t; - using ComputeDataType = float; + using XDataType = ck_tile::bf16_t; + using SmoothScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; }; // runtime args diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index d2c4df105838d6ae1ebe3ef5058330a74fd55585..c4faa35e33007e0622e43560375b95d69499d581 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[]) .insert("k", "4", "topk") .insert("unit", "32", "unit_size") .insert("moe_buf_size", "0", "moe_buf_size") + .insert("local_eid", + "-1", + "a list of experts enabled as local expert. e.g. 
\"0,1,4,5\"\n" + "please make sure eid is in ascending order!") .insert("seed", "-1", "seed to be used, -1 means random every time") .insert("kname", "0", "when set to 1 it will print kernel name") .insert("warmup", "5", "number of iterations before benchmark the kernel") @@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) int kname = args.get_int("kname"); int warmup = args.get_int("warmup"); int repeat = args.get_int("repeat"); + int max_output_ids = ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); @@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args) return false; } + bool local_expert_masking = args.get_str("local_eid") != "-1"; + auto local_expert_masking_host = [&]() { + if(local_expert_masking) + { + auto local_eid = args.get_int_vec("local_eid"); + // std::vector v_ {num_experts, 0}; + ck_tile::HostTensor v_{{num_experts}}; + v_.SetZero(); + for(auto eid : local_eid) + { + if(eid >= num_experts) + { + throw std::runtime_error( + "local_eid larger than number of expert, please check"); + } + v_.mData[eid] = 1; + } + return v_; + } + else + // return std::vector{}; + return ck_tile::HostTensor{{1}}; + }(); + // tokens already considered batch size ck_tile::HostTensor topk_ids_host({tokens, topk}, {topk, 1}); ck_tile::HostTensor weights_host({tokens, topk}, {topk, 1}); @@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) sorted_expert_ids_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem local_expert_masking_dev( + local_expert_masking_host.get_element_space_size_in_bytes()); topk_ids_dev.ToDevice(topk_ids_host.data()); weights_dev.ToDevice(weights_host.data()); @@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args) { moe_buf_dev.ToDevice(moe_buf_host.data()); } + if(local_expert_masking) + local_expert_masking_dev.ToDevice(local_expert_masking_host.data()); - moe_sorting_trait trait{index_prec, weight_prec}; + moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking}; moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(), + local_expert_masking ? 
local_expert_masking_dev.GetDeviceBuffer() + : nullptr, sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(), sorted_expert_ids_dev.GetDeviceBuffer(), @@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args) warmup, repeat}; auto ms = moe_sorting(trait, karg, sc); - printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ", + printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ", index_prec.c_str(), weight_prec.c_str(), tokens, num_experts, - topk, - ms); + topk); + + if(local_expert_masking) + { + printf("local_eid:%s, ", args.get_str("local_eid").c_str()); + } + if(ms < 0) printf("not supported\n"); + else + printf("ms:%f, ", ms); fflush(stdout); if(ms < 0) { @@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args) int32_t ref_total_tokens_post_pad = 0; ck_tile::reference_moe_sorting(topk_ids_host, weights_host, + local_expert_masking_host, sorted_ids_ref, sorted_weights_ref, sorted_expert_ids_ref, ref_total_tokens_post_pad, num_experts, - unit_size); + unit_size, + local_expert_masking); rtn &= ck_tile::check_err( sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); rtn &= ck_tile::check_err(sorted_weights_host, @@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); } rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; + printf("total_tokens_post_pad:%d(%d), ", + ref_total_tokens_post_pad, + sorted_id_cnt_host.mData[0]); } - printf("valid:%s\n", rtn ? "y" : "n"); + printf("valid:%s", rtn ? "y" : "n"); + fflush(stdout); + if(!rtn) + printf(", (%d)", seed); + printf("\n"); fflush(stdout); return rtn; } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 723fb3f69f1e70877d053b59cdb3ea25864089c3..abff24a66975d5593ba0d01fbd972f35c2586309 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -3,6 +3,12 @@ #include "moe_sorting_api.hpp" +#ifndef MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + +#if !MOE_SORTING_USE_EX_KERNEL + #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \ @@ -17,6 +23,67 @@ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#else + +#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \ + constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ + constexpr bool sub_token_onshot = sub_token_onshot_; \ + constexpr bool local_expert_masking = local_expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemEx; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \ + if(row_ % 8 == 0) \ + { \ + MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 4 == 0) \ + { \ + MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 2 == 0) \ + { \ + 
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \ + } + +#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \ + if(is_sub_token_onshot) \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \ + } + +#define MOE_SORTING_DISPATCH_EMASK_(row_) \ + if(is_local_expert_masking) \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, true) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, false) \ + } + +#endif + +#if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH(unroll_num_) \ if(a.num_experts <= 8) \ { \ @@ -38,11 +105,13 @@ { \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ } +#endif float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") { +#if !MOE_SORTING_USE_EX_KERNEL if(a.num_experts > 127) { printf("lds size exceed, only support experts <127 \n"); @@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi MOE_SORTING_DISPATCH(4); } } +#else + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); + auto sub_token_ = r_ - 2; + r_ = (r_ - 2) / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; + bool is_local_expert_masking = t.local_expert_masking; + (void)c_; + + MOE_SORTING_DISPATCH_EMASK_(r_); + // MOE_SORTING_DISPATCH_ETILE(0, 0); +#endif } return -1; } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index 0cb393f7dedd450d03fdadcaac1c929791ff4ef8..5bda4d368a4ca394bb66d3a22edc972f47788264 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -10,7 +10,8 @@ struct moe_sorting_trait { std::string index_type; - std::string weight_type; // currently always float + std::string weight_type; // currently always float + bool local_expert_masking; // if mask experts as local expert }; struct moe_sorting_args : public ck_tile::MoeSortingHostArgs diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index 3ff8a7332daa45d8882430a50e20c7ed86a9454a..cf2c2e164b75df25e478b01aed46478af69bb6df 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11 $EXE -t=1 -e=1 -k=1 $EXE -t=99 -e=2 -k=1 $EXE -t=333 -e=99 -k=13 +$EXE -t=11 -e=256 -k=5 +$EXE -t=64 -e=455 -k=8 +$EXE -t=777 -e=802 -k=99 +$EXE -t=4097 -e=906 -k=51 $EXE -t=128 -e=32 -k=5 -moe_buf_size=262144 +$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11 +$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19 +$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33 +$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129 diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp index f43626147fd19f52db67e8b19184bd7fe651968f..39481e2c83c8129d0b8a21de3f9183e97bd932aa 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1024_instance.cpp @@ -15,8 +15,13 @@ template float 
moe_smoothquant_>(const S&, A); #endif -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1536_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1536_instance.cpp index e380520fce79f2a32ed12410c595dc3671ee4b8a..6feccbdaff3959fc246bb88cfaf2669bd6f11398 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1536_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n1536_instance.cpp @@ -6,8 +6,13 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n2048_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n2048_instance.cpp index 4d536cd61d8b2bec48882bdbb5ea36053c317c58..0e2c9366338ed53c07f134d08437b5618a3fd058 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n2048_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n2048_instance.cpp @@ -6,9 +6,14 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n256_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n256_instance.cpp index b38a4733a402dc0fab820b0b1422b2651446a3bd..373cb0352b6b13d7a9089c8099a36c640656701b 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n256_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n256_instance.cpp @@ -6,7 +6,11 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template 
float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n3072_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n3072_instance.cpp index c5c170aef1beb4674a3730d4f90a6cf111885129..c0c778f36c98d102b3a15bd0d9c14380892626df 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n3072_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n3072_instance.cpp @@ -6,9 +6,13 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_instance.cpp index 0e48a1b69153ab797a3887fc9302540fe31a2fcb..47cffd5fc2f067e176ed77e5f59d3f7eef113893 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_instance.cpp @@ -6,9 +6,13 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_tp_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_tp_instance.cpp index 4af42c6c804bd0d8235ca76d6beb8d938ead76b2..726d6018a6bf99073a878ffe702f1a20f765b9eb 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_tp_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n4096_tp_instance.cpp @@ -6,9 +6,13 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); 
+template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n512_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n512_instance.cpp index ea611a183469cba4bac9ab6ce02639e54d87ade3..6026d509d0c28825a82b571deb68184f2d6501ac 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n512_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n512_instance.cpp @@ -6,8 +6,13 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n64_n128_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n64_n128_instance.cpp index a6209820e60a33313444439f7fb210a5e0497335..3924662fe530c27b89e173edd282939b0531d0bd 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n64_n128_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n64_n128_instance.cpp @@ -6,7 +6,11 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n768_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n768_instance.cpp index f569dedf35370d377f4c5d346942d07cf77647c3..00d5c980d7ab0eed32d4c8582c5c09a6b74b8ac6 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n768_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_bf16_n768_instance.cpp @@ -6,7 +6,11 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp index 3793adb5c5a16102956a68132b04e66fd29a93f0..c908739efa57f00848e9f14e402db130b329d89b 100644 --- 
a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1024_instance.cpp @@ -15,8 +15,13 @@ template float moe_smoothquant_>(const S&, A); #endif -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1536_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1536_instance.cpp index 4bf9cb1a49a4630a54c6ca6c3a442232571696ba..65e9470cdeb62723c8e2114691947e9b924ab078 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1536_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n1536_instance.cpp @@ -6,8 +6,13 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n2048_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n2048_instance.cpp index eb0d0fe103a3c8af5fa44bb930d90fc5cb8e33dc..421352f45ffbc3fa30ac92b1347d928417150138 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n2048_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n2048_instance.cpp @@ -6,9 +6,13 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n256_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n256_instance.cpp index 36bc0de15048a371b0d9243f1210abf6e672599b..f102cb6d60c60a46e94bc423394aa6f86b9c35e8 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n256_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n256_instance.cpp @@ -6,7 
+6,11 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n3072_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n3072_instance.cpp index fa6f53b2d4a6aeaf25427232ee74f6438d1dd843..ad7b9e3d158641b2b91458c3aea32e1479ae4999 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n3072_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n3072_instance.cpp @@ -6,9 +6,13 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_instance.cpp index 9b7462ab92f8d2c8f0907e759684b2c530f6335c..bb79ec7ab4222c07cc0f27c08b756a90de7b95d4 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_instance.cpp @@ -6,9 +6,13 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_tp_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_tp_instance.cpp index 8911bc22958334e30a02bf281730286b4821aef6..766c60689f7884341c56d4336f56722d9012548d 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_tp_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_tp_instance.cpp @@ -6,9 +6,13 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float 
moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n512_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n512_instance.cpp index 07783ac168e514c743a9830d6cdd54820276d360..6c24e1ebe014d51aa519853a2f6430da7327d1ce 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n512_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n512_instance.cpp @@ -6,8 +6,13 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n64_n128_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n64_n128_instance.cpp index a5ab56a76c328f89a9f5acac471d798b6cdab88a..df785eefeff22b0fcbf2208c36c7ea5f592d5484 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n64_n128_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n64_n128_instance.cpp @@ -6,7 +6,11 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n768_instance.cpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n768_instance.cpp index 4272cbafc69a54724ecbd4057914eefe3b6e5f8f..d89f1c3bbf625da28c12f9b8f28c22a1fa47d4e4 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n768_instance.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n768_instance.cpp @@ -6,7 +6,11 @@ // clang-format off // rm rn tm tn vn pd 2p -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); -template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); + +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); +template float moe_smoothquant_>(const S&, A); // clang-format on diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fwd_api.cpp 
b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fwd_api.cpp index a65d3fde667d2e65b3083e233a987430115f8604..9d86c54b1ad7ad538c0da2de8816fa78ef705185 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fwd_api.cpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fwd_api.cpp @@ -4,7 +4,8 @@ #include #include "moe_smoothquant.hpp" -template -using trait_ = moe_smoothquant_traits_; -template +template float moe_smoothquant_dispatch(moe_smoothquant_traits /*t*/, moe_smoothquant_args a, const ck_tile::stream_config& s) { float r = -1; // clang-format off - // rm rn tm tn vn pd 2p + // rm rn tm tn vn pd 2p if(a.hidden_size <= 64) { - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); } else if(a.hidden_size <= 128) { if (a.hidden_size % 2 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); } else if(a.hidden_size <= 256) { if (a.hidden_size % 4 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 2 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); } else if(a.hidden_size <= 512) { if (a.hidden_size % 8 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 4 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 2 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); } else if(a.hidden_size <= 768) { if (a.hidden_size % 4 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 2 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); } else if(a.hidden_size <= 1024) { if (a.hidden_size % 8 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 4 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 2 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); } else if(a.hidden_size <= 1536) { if (a.hidden_size % 8 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 4 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 2 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); } else if(a.hidden_size <= 2048) { if (a.hidden_size % 8 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 4 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 2 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); } else if(a.hidden_size <= 3072) { if (a.hidden_size % 8 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 4 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 2 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); } else if(a.hidden_size <= 4096) { if (a.hidden_size % 8 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if 
(a.hidden_size % 4 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 2 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); } else if(a.hidden_size > 4096) { if (a.hidden_size % 8 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 4 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else if (a.hidden_size % 2 == 0) - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); else - r = moe_smoothquant_>(s, a); + r = moe_smoothquant_>(s, a); } return r; // clang-format on @@ -132,13 +134,21 @@ float moe_smoothquant(moe_smoothquant_traits t, moe_smoothquant_args a, const ck_tile::stream_config& s) { - if(t.data_type.compare("fp16") == 0) + if(t.in_type.compare("fp16") == 0 && t.out_type == "int8") { - return moe_smoothquant_dispatch(t, a, s); + return moe_smoothquant_dispatch(t, a, s); } - else if(t.data_type.compare("bf16") == 0) + else if(t.in_type.compare("fp16") == 0 && t.out_type == "fp8") { - return moe_smoothquant_dispatch(t, a, s); + return moe_smoothquant_dispatch(t, a, s); + } + else if(t.in_type.compare("bf16") == 0 && t.out_type == "int8") + { + return moe_smoothquant_dispatch(t, a, s); + } + else if(t.in_type.compare("bf16") == 0 && t.out_type == "fp8") + { + return moe_smoothquant_dispatch(t, a, s); } else throw std::runtime_error("Without supported instances!"); diff --git a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp index 88d3000910a01026191928772408c7d16f389092..885d9ff7bf5319d763f90983f78d73a46fca5bb2 100644 --- a/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp +++ b/example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include "moe_smoothquant.hpp" @@ -11,7 +11,8 @@ using S = ck_tile::stream_config; using A = moe_smoothquant_args; -template -using trait_ = moe_smoothquant_traits_ float moe_smoothquant_(const S& s, A a) { - using DataType = typename Traits_::DataType; + using InputType = typename Traits_::InputType; + using OutputType = typename Traits_::OutputType; using PipelineProblem = ck_tile::SmoothquantPipelineProblem< - typename MoeSmoothquantTypeConfig::XDataType, - typename MoeSmoothquantTypeConfig::XScaleDataType, - typename MoeSmoothquantTypeConfig::ComputeDataType, - typename MoeSmoothquantTypeConfig::YScaleDataType, - typename MoeSmoothquantTypeConfig::QYDataType, + typename MoeSmoothquantTypeConfig::XDataType, + typename MoeSmoothquantTypeConfig::SmoothScaleDataType, + typename MoeSmoothquantTypeConfig::ComputeDataType, + typename MoeSmoothquantTypeConfig::YScaleDataType, + typename MoeSmoothquantTypeConfig::QYDataType, typename Traits_::Shape, Traits_::kPadN, Traits_::kTwoPass>; diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp index f1b374adbf87552791667aa2758bc3c447e7ac3f..dc5b397c854fbedaa38d5dbf8f395ab3ecbbb186 100644 --- a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp @@ -63,7 +63,8 @@ auto create_args(int argc, char* argv[]) .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size") .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") - .insert("prec", "fp16", "precision") + .insert("prec_i", "fp16", "input precision, fp16/bf16") + .insert("prec_o", "int8", "precision, int8/fp8") .insert("warmup", "5", "cold iter") .insert("repeat", "20", "hot iter"); @@ -71,7 +72,7 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::index_t tokens = arg_parser.get_int("t"); @@ -81,7 +82,8 @@ bool run(const ck_tile::ArgParser& arg_parser) stride = hidden_size; ck_tile::index_t experts = arg_parser.get_int("e"); ck_tile::index_t topk = arg_parser.get_int("k"); - std::string data_type = arg_parser.get_str("prec"); + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); int kname = arg_parser.get_int("kname"); int do_validation = arg_parser.get_int("v"); int warmup = arg_parser.get_int("warmup"); @@ -89,17 +91,17 @@ bool run(const ck_tile::ArgParser& arg_parser) assert(stride >= hidden_size); - using TypeConfig = MoeSmoothquantTypeConfig; + using TypeConfig = MoeSmoothquantTypeConfig; - using XDataType = typename TypeConfig::XDataType; - using XScaleDataType = typename TypeConfig::XScaleDataType; - using YScaleDataType = typename TypeConfig::YScaleDataType; - using QYDataType = typename TypeConfig::QYDataType; - using ComputeDataType = typename TypeConfig::ComputeDataType; + using XDataType = typename TypeConfig::XDataType; + using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = typename TypeConfig::ComputeDataType; // host verify ck_tile::HostTensor x_host({tokens, hidden_size}, {stride, 1}); - ck_tile::HostTensor xscale_host({experts * hidden_size}); + ck_tile::HostTensor smscale_host({experts * hidden_size}); ck_tile::HostTensor topk_ids_host({tokens, topk}); 
ck_tile::HostTensor yscale_host_ref({topk * tokens}, {1}); @@ -110,26 +112,26 @@ bool run(const ck_tile::ArgParser& arg_parser) topid_unique_gen(topk_ids_host.mData, tokens, topk, experts, 11937); ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); - ck_tile::FillUniformDistribution{1e-3, .5f}(xscale_host); + ck_tile::FillUniformDistribution{1e-3, .5f}(smscale_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem topk_ids_buf(topk_ids_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); - xscale_buf.ToDevice(xscale_host.data()); + smscale_buf.ToDevice(smscale_host.data()); topk_ids_buf.ToDevice(topk_ids_host.data()); - std::cout << "[" << data_type << "]" + std::cout << "[" << prec_i << "-" << prec_o << "]" << " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride << ", experts:" << experts << ", topk:" << topk << std::flush; - moe_smoothquant_traits traits{data_type}; + moe_smoothquant_traits traits{prec_i, prec_o}; moe_smoothquant_args args{x_buf.GetDeviceBuffer(), - xscale_buf.GetDeviceBuffer(), + smscale_buf.GetDeviceBuffer(), topk_ids_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(), @@ -143,9 +145,10 @@ bool run(const ck_tile::ArgParser& arg_parser) float ave_time = moe_smoothquant( traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); - std::size_t num_byte = - sizeof(XDataType) * tokens * hidden_size + sizeof(XScaleDataType) * topk * hidden_size + - sizeof(YScaleDataType) * topk * tokens + sizeof(QYDataType) * topk * tokens * hidden_size; + std::size_t num_byte = sizeof(XDataType) * tokens * hidden_size + + sizeof(SmoothScaleDataType) * topk * hidden_size + + sizeof(YScaleDataType) * topk * tokens + + sizeof(QYDataType) * topk * tokens * hidden_size; float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; @@ -165,11 +168,11 @@ bool run(const ck_tile::ArgParser& arg_parser) for(int i_h = 0; i_h < hidden_size; ++i_h) { - auto v_xscale = ck_tile::type_convert( - xscale_host(i_expert * hidden_size + i_h)); + auto v_smscale = ck_tile::type_convert( + smscale_host(i_expert * hidden_size + i_h)); auto v_x = ck_tile::type_convert(x_host(i_token, i_h)); - // y_host(i_token * topk + i_topk, i_h) = v_x * v_xscale; - y_host(i_topk * tokens + i_token, i_h) = v_x * v_xscale; + // y_host(i_token * topk + i_topk, i_h) = v_x * v_smscale; + y_host(i_topk * tokens + i_token, i_h) = v_x * v_smscale; } } }; @@ -250,14 +253,23 @@ int main(int argc, char* argv[]) if(!result) return -1; - const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + const std::string prec_i = arg_parser.get_str("prec_i"); + const std::string prec_o = arg_parser.get_str("prec_o"); + if(prec_i == "fp16" && prec_o == "int8") + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "fp16" && prec_o == "fp8") + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "int8") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 
0 : -2; } - else if(data_type == "bf16") + else if(prec_i == "bf16" && prec_o == "fp8") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } return -3; diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp index 9f9adda90fa74141c4178c6c775aa296fd5b9651..c1b90b14b2e45e618eff3d3ac9d310ab48dd6d3a 100644 --- a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,27 +8,14 @@ #include "ck_tile/ops/smoothquant.hpp" #include -template -struct MoeSmoothquantTypeConfig; - -template <> -struct MoeSmoothquantTypeConfig -{ - using XDataType = ck_tile::half_t; - using XScaleDataType = float; - using YScaleDataType = float; - using QYDataType = ck_tile::int8_t; - using ComputeDataType = float; -}; - -template <> -struct MoeSmoothquantTypeConfig +template +struct MoeSmoothquantTypeConfig { - using XDataType = ck_tile::bf16_t; - using XScaleDataType = float; - using YScaleDataType = float; - using QYDataType = ck_tile::int8_t; - using ComputeDataType = float; + using XDataType = InputType; + using SmoothScaleDataType = float; + using YScaleDataType = float; + using QYDataType = OutputType; + using ComputeDataType = float; }; // runtime args @@ -37,7 +24,8 @@ struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs }; // this is used to pattern-match internl kernel implementation, not to instantiate kernel -template struct moe_smoothquant_traits_ { - using DataType = ck_tile::remove_cvref_t; + using InputType = ck_tile::remove_cvref_t; + using OutputType = ck_tile::remove_cvref_t; static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); @@ -108,7 +97,8 @@ float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a); // This is the public API, will be generated by script struct moe_smoothquant_traits { - std::string data_type; + std::string in_type; // input type + std::string out_type; // output type }; float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/14_moe_smoothquant/script/smoke_test.sh b/example/ck_tile/14_moe_smoothquant/script/smoke_test.sh index 3bb62d37b9b60b02a6c0e951beaf358febb06c4a..e01f3de10a017ea2a9a1fd54bb55d6f9e7a970b5 100755 --- a/example/ck_tile/14_moe_smoothquant/script/smoke_test.sh +++ b/example/ck_tile/14_moe_smoothquant/script/smoke_test.sh @@ -2,29 +2,31 @@ EXE=build/bin/tile_example_moe_smoothquant for pr_i in "fp16" "bf16" ; do -$EXE -prec=$pr_i -t=99 -h=13 -$EXE -prec=$pr_i -t=17 -h=16 -$EXE -prec=$pr_i -t=1 -h=100 -$EXE -prec=$pr_i -t=4 -h=128 -$EXE -prec=$pr_i -t=80 -h=127 -$EXE -prec=$pr_i -t=22 -h=255 -stride=256 -$EXE -prec=$pr_i -t=7 -h=599 -$EXE -prec=$pr_i -t=19 -h=512 -$EXE -prec=$pr_i -t=33 -h=313 -stride=1000 -$EXE -prec=$pr_i -t=11 -h=510 -$EXE -prec=$pr_i -t=171 -h=676 -stride=818 -$EXE -prec=$pr_i -t=91 -h=636 -$EXE -prec=$pr_i -t=12 -h=768 -stride=800 -$EXE -prec=$pr_i -t=100 -h=766 -stride=812 -$EXE -prec=$pr_i -t=31 -h=1024 -$EXE -prec=$pr_i -t=64 -h=1000 -stride=1004 -$EXE -prec=$pr_i -t=8 -h=1501 -$EXE -prec=$pr_i -t=3 -h=1826 -$EXE -prec=$pr_i -t=5 -h=2040 -$EXE -prec=$pr_i -t=7 -h=2734 -$EXE 
-prec=$pr_i -t=1 -h=3182 -$EXE -prec=$pr_i -t=9 -h=4096 -$EXE -prec=$pr_i -t=3 -h=8192 -$EXE -prec=$pr_i -t=1 -h=10547 -$EXE -prec=$pr_i -t=3 -h=17134 +for pr_o in "int8" "fp8" ; do +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=99 -h=13 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=17 -h=16 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=1 -h=100 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=4 -h=128 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=80 -h=127 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=22 -h=255 -stride=256 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=7 -h=599 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=19 -h=512 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=33 -h=313 -stride=1000 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=11 -h=510 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=171 -h=676 -stride=818 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=91 -h=636 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=12 -h=768 -stride=800 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=100 -h=766 -stride=812 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=31 -h=1024 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=64 -h=1000 -stride=1004 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=8 -h=1501 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=3 -h=1826 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=5 -h=2040 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=7 -h=2734 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=1 -h=3182 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=9 -h=4096 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=3 -h=8192 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=1 -h=10547 +$EXE -prec_i=$pr_i -prec_o=$pr_o -t=3 -h=17134 +done done diff --git a/example/ck_tile/15_fused_moe/README.md b/example/ck_tile/15_fused_moe/README.md index dd566c16673e05e365a88c9c14292a0a35e676de..089e1de78e00aa9d6b9cd8b3627f79a08e26a44b 100644 --- a/example/ck_tile/15_fused_moe/README.md +++ b/example/ck_tile/15_fused_moe/README.md @@ -8,6 +8,9 @@ The benefit of this fused-moe: * much less kernel instance, easy to maintain # Implementation and feature support +## NOTES: +Currently the gate+up case in fp16 can easily overflow the accumulator past the fp16 max (65504), resulting in INF. Please use bf16 for the gate+up case; the API does not check for this. + ## moe-sorting this is a common pre-process step before the actual moe-gemm. The purpose is to transform the moe loop from token-by-token to expert-by-expert, making sure every workgroup is working for a single expert (B matrix).
Besides, we extend this op to do the zeroing of the output buffer (to be used as a reduce buffer with atomics) @@ -39,7 +42,7 @@ summary of the key design of this fused-moe operator: // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // -// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated) // * this could be larger than actual, since actual tokens are on GPU // // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] diff --git a/example/ck_tile/15_fused_moe/fused_moe.hpp b/example/ck_tile/15_fused_moe/fused_moe.hpp index 6bd7688d8a0bde729c1e9d303f95fb76db822b4d..1f2246fa4a9d34d9632bdd154e9c8c52b1f0e55a 100644 --- a/example/ck_tile/15_fused_moe/fused_moe.hpp +++ b/example/ck_tile/15_fused_moe/fused_moe.hpp @@ -8,14 +8,15 @@ struct fused_moe_args { - const void* a_ptr; // [m, k], input token - const void* a_scale_ptr; // [m, 1], token scale - const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w]) - const void* g_scale_ptr; // [e, 1, n], gate(up) scale - const void* d_scale_ptr; // [e, 1, k], down scale - const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input - void* o_ptr; // [m, k], output token (no need to do zeroing) + const void* a_ptr; // [m, k], input token + const void* a_scale_ptr; // [m, 1], token scale + const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w]) + const void* g_scale_ptr; // [e, 1, n], gate(up) scale + const void* d_scale_ptr; // [e, 1, k], down scale + const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input + const void* local_expert_mask_ptr; // [e], local expert mask for EP + void* o_ptr; // [m, k], output token (no need to do zeroing) const void* topk_ids_ptr; // [tokens, topk] const void* topk_weight_ptr; // [tokens, topk] @@ -26,7 +27,7 @@ struct fused_moe_args ck_tile::index_t block_m; // block_m, used to divide the input - ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2 + ck_tile::index_t intermediate_size; // n / TP, for Gate; Up and Down also use this value ck_tile::index_t num_tokens; // input number of tokens for current iteration ck_tile::index_t num_experts; // number of groups ck_tile::index_t topk; // need this?
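For reference, a minimal host-side sketch of the buffer sizing implied by the README formula above (assuming only that formula; `sorted_ids_capacity` and `round_up_to_multiple` are illustrative names, with the latter standing in for `ck_tile::integer_least_multiple`):

```cpp
#include <cstdint>

// Round x up to the nearest multiple of m (stand-in for ck_tile::integer_least_multiple).
static int64_t round_up_to_multiple(int64_t x, int64_t m) { return ((x + m - 1) / m) * m; }

// Worst-case capacity of the sorted token-id / sorted weight buffers produced by moe-sorting,
// following the updated formula: topk * tokens + num_experts * M_a - topk,
// additionally rounded up to a multiple of the block size M_a (block_m).
static int64_t
sorted_ids_capacity(int64_t tokens, int64_t topk, int64_t num_experts, int64_t block_m)
{
    return round_up_to_multiple(topk * tokens + num_experts * block_m - topk, block_m);
}

// Example: tokens=11, topk=5, num_experts=256, block_m=32
//   -> round_up(55 + 8192 - 5, 32) = round_up(8242, 32) = 8256 ids.
```

This mirrors the `max_output_ids` computation in the moe_sorting example earlier in this patch.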
@@ -45,8 +46,11 @@ struct fused_moe_traits std::string prec_sq; // smooth quant scale std::string prec_kw; // topk-weight data type int block_m; - int gate_only; + int activation; // 0:gelu, 1:silu + int gate_only; // 0:g1u0, 1:g1u1 int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant + + bool local_expert_masking; // if mask experts as local expert }; float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/15_fused_moe/fused_moegemm.hpp b/example/ck_tile/15_fused_moe/fused_moegemm.hpp index b8e51475ad088ba01d037a854fd3e4afe7de6f37..8a1027c80cdea639797613c95c74536fdac2acf0 100644 --- a/example/ck_tile/15_fused_moe/fused_moegemm.hpp +++ b/example/ck_tile/15_fused_moe/fused_moegemm.hpp @@ -77,7 +77,8 @@ struct fused_moegemm_traits std::string prec_sq; // smooth quant scale std::string prec_kw; // topk-weight data type int block_m; - int gate_only; + int activation; // 0:gelu, 1:silu + int gate_only; // 0:g1u0, 1:g1u1 int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant }; diff --git a/example/ck_tile/15_fused_moe/fused_moesorting.hpp b/example/ck_tile/15_fused_moe/fused_moesorting.hpp index 57dace9b41fff4cf87d7faf42ea19fdfd9a06d26..a3ff8c5bf7ef47f581853c4f5895cee599bf22cd 100644 --- a/example/ck_tile/15_fused_moe/fused_moesorting.hpp +++ b/example/ck_tile/15_fused_moe/fused_moesorting.hpp @@ -10,7 +10,8 @@ struct fused_moesorting_trait { std::string index_type; - std::string weight_type; // currently always float + std::string weight_type; // currently always float + bool local_expert_masking; // if mask experts as local expert }; struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index bfc0ce409677a66c0b24e59d63b15313c21b5d9c..cf9ff2edbab5225adf33565037a64a1d22efc60b 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf return 1; }(); - auto t0 = fused_moesorting_trait{"int32", "fp32"}; + auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking}; auto a0 = fused_moesorting_args{ a.topk_ids_ptr, // const void* p_topk_ids; a.topk_weight_ptr, // const void* p_weights; + a.local_expert_mask_ptr, // const void* p_local_expert_mask; a.sorted_token_ids_ptr, // void* p_sorted_token_ids; a.sorted_weight_ptr, // void* p_sorted_weights; a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; @@ -41,6 +42,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf t.prec_sq, t.prec_kw, t.block_m, + t.activation, t.gate_only, t.fused_quant}; auto a1 = fused_moegemm_args{ diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp index c1a4c495c3a2d4982289f1d43106cf175793bbad..49d29bad51c415dc5f0feb288fd4d5bd939f8b30 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp @@ -17,15 +17,67 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: // clang-format off float r = -1; if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && - t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only 
== 1) + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0) { - using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; + constexpr ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0) + { + constexpr ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0) + { + constexpr ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0) + { + constexpr ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1) + { + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1) + { + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1) + { + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; r = fused_moegemm_(s, a); } else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && - t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1) { - using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; r = fused_moegemm_(s, a); } // clang-format on diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp index 
5872179ef71688328bda161d020535009ba2c4f2..343ddbed13ab4873b0120e0fb7fdeff8d8bfcc35 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp @@ -21,21 +21,31 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) typename Ts_::BlockTile_1, typename Ts_::WarpPerBlock_0, typename Ts_::WarpTile_0>; - using f_problem = - ck_tile::FusedMoeGemmPipelineProblem; + + constexpr auto get_activation_ = []() { + if constexpr(Ts_::Activation == 0) + { + return ck_tile::element_wise::FastGeluAsm{}; + } + else + return ck_tile::element_wise::Silu{}; + }; + using f_act_ = ck_tile::remove_cvref_t; + + using f_problem = ck_tile::FusedMoeGemmPipelineProblem; // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx; using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk; diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp index cc476685defeb85941d8b6dcc86f2b25469ed3e3..a7e53cc6548f2d3ddf3590b40c709549b7cf8d4d 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp @@ -15,7 +15,8 @@ template typename WarpPerBlock_, - typename WarpTile_, // seq<*,*,*>, used to select mfma + typename WarpTile_, // seq<*,*,*>, used to select mfma + ck_tile::index_t Activation_ = 0, // 0: Gelu 1: Silu ck_tile::index_t GateOnly_ = 0, ck_tile::index_t FusedQuant_ = 0> struct fmoe_ // traits, ugly name, only used for internal @@ -44,10 +45,11 @@ struct fmoe_ // traits, ugly name, only used for internal using WarpPerBlock_0 = ck_tile::remove_cvref_t; using WarpTile_0 = ck_tile::remove_cvref_t; - using BlockTile_1 = ck_tile::sequence; + using BlockTile_1 = ck_tile::sequence; using WarpPerBlock_1 = ck_tile::remove_cvref_t; using WarpTile_1 = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu static constexpr ck_tile::index_t GateOnly = GateOnly_; static constexpr ck_tile::index_t FusedQuant = FusedQuant_; }; diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp index 93f9c77869ef8b4a4d941d44b17833124d8a4560..5691743565ca5ffd67fbaccf3e227310b7c7f143 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp @@ -8,7 +8,18 @@ // clang-format off template float fused_moegemm_< - fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0> + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0> >(const ck_tile::stream_config& s, fused_moegemm_args a); +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); // clang-format on diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp index b8a823e8edf6596a5b449a5a7b5a87f236976542..74632df415ab28e0554b5f21c98a2b293c27e107 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp +++ 
b/example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp @@ -8,7 +8,19 @@ // clang-format off template float fused_moegemm_< - fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0> + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0> >(const ck_tile::stream_config& s, fused_moegemm_args a); // clang-format on diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 7ca24c5c9a2ce230640e30efa6b8b2ec42f445c4..7aedaa9317bbbfd7ef6562bc94f3de988cdc68d4 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -3,6 +3,12 @@ #include "fused_moesorting.hpp" +#ifndef MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + +#if !MOE_SORTING_USE_EX_KERNEL + #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \ @@ -17,6 +23,67 @@ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#else + +#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \ + constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ + constexpr bool sub_token_onshot = sub_token_onshot_; \ + constexpr bool local_expert_masking = local_expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemEx; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \ + if(row_ % 8 == 0) \ + { \ + MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 4 == 0) \ + { \ + MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 2 == 0) \ + { \ + MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \ + } + +#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \ + if(is_sub_token_onshot) \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \ + } + +#define MOE_SORTING_DISPATCH_EMASK_(row_) \ + if(is_local_expert_masking) \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, true) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, false) \ + } + +#endif + +#if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH(unroll_num_) \ if(a.num_experts <= 8) \ { \ @@ -38,11 +105,13 @@ { \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ } +#endif float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" 
&& t.index_type == "int32") { +#if !MOE_SORTING_USE_EX_KERNEL if(a.num_experts > 127) { printf("lds size exceed, only support experts <127 \n"); @@ -83,6 +152,19 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til MOE_SORTING_DISPATCH(4); } } +#else + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); + auto sub_token_ = r_ - 2; + r_ = (r_ - 2) / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; + bool is_local_expert_masking = t.local_expert_masking; + (void)c_; + + MOE_SORTING_DISPATCH_EMASK_(r_); + // MOE_SORTING_DISPATCH_ETILE(0, 0); +#endif } return -1; } diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index 2f44f903e9975a0b9dcaed4b4014140ade33d82b..95adcd684bd20389a9243235cace3eebe0c886b4 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[]) .insert( "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm") + .insert("act", "0", "activation after first gemm. 0:gelu, 1:silu") .insert("balance", "0", "if set to 1, will try balance the expert in topk-ids(convenient for testing)") .insert("init", - "2", - "init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized" + "1", + "init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand " + "normalized[0, 1]" "normalized(slow)") .insert("seed", "11939", "seed used to do random") .insert("warmup", "5", "cold iter") @@ -135,30 +137,32 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t intermediate_size = arg_parser.get_int("i"); ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t block_m = arg_parser.get_int("bm"); + ck_tile::index_t activation = arg_parser.get_int("act"); if(stride < 0) stride = hidden_size; - std::string prec_i = arg_parser.get_str("prec_i"); - std::string prec_w = arg_parser.get_str("prec_w"); - std::string prec_o = arg_parser.get_str("prec_o"); - std::string prec_st = arg_parser.get_str("prec_st"); - std::string prec_sw = arg_parser.get_str("prec_sw"); - std::string prec_sq = arg_parser.get_str("prec_sq"); - std::string prec_kw = arg_parser.get_str("prec_kw"); - prec_st = (prec_st == "auto") ? "fp32" : prec_st; - prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; - prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; - prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - int fused_quant = arg_parser.get_int("fquant"); - int gate_only = arg_parser.get_int("gate_only"); - int api = arg_parser.get_int("api"); - int balance = arg_parser.get_int("balance"); - int tp = arg_parser.get_int("tp"); - int init = arg_parser.get_int("init"); - uint32_t seed = arg_parser.get_uint32("seed"); + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_w = arg_parser.get_str("prec_w"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_st = arg_parser.get_str("prec_st"); + std::string prec_sw = arg_parser.get_str("prec_sw"); + std::string prec_sq = arg_parser.get_str("prec_sq"); + std::string prec_kw = arg_parser.get_str("prec_kw"); + prec_st = (prec_st == "auto") ? 
"fp32" : prec_st; + prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; + prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; + prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int fused_quant = arg_parser.get_int("fquant"); + int gate_only = arg_parser.get_int("gate_only"); + int api = arg_parser.get_int("api"); + int balance = arg_parser.get_int("balance"); + int tp = arg_parser.get_int("tp"); + int init = arg_parser.get_int("init"); + uint32_t seed = arg_parser.get_uint32("seed"); + bool local_expert_masking = false; // TODO... // w0 (Gate+Up or Gate only, N size) ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp; @@ -194,11 +198,14 @@ bool run(const ck_tile::ArgParser& arg_parser) return std::string(", st:") + std::to_string(stride); }(); - std::cout << "[" << api_str << "|" << prec_str << "]" - << " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str - << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp - << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1 - << ", go:" << gate_only << ", q:" << fused_quant << std::flush; + std::cout + << "[" << api_str << "|" << prec_str << "]" + << " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str + << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp + << ", act:" + << activation + // << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1 + << (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush; using TypeConfig = FusedMoeGemmTypeConfig; using ADataType = typename TypeConfig::ADataType; @@ -224,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor sy_host({shared_intermediate_size_1}); // smooth-quant ck_tile::HostTensor topk_ids_host({tokens, topk}); // to be sort ck_tile::HostTensor topk_weight_host({tokens, topk}); // to be sort + ck_tile::HostTensor local_expert_mask_host({experts}); int max_num_tokens_padded = topk * tokens + experts * block_m - topk; ck_tile::HostTensor sorted_token_ids_host({max_num_tokens_padded}); @@ -349,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem sg_buf(sg_host); ck_tile::DeviceMem sd_buf(sd_host); ck_tile::DeviceMem sy_buf(sy_host); + ck_tile::DeviceMem local_expert_mask_buf(local_expert_mask_host); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem topk_ids_buf(topk_ids_host); @@ -370,8 +379,10 @@ bool run(const ck_tile::ArgParser& arg_parser) prec_sq, prec_kw, block_m, + activation, gate_only, - fused_quant}; + fused_quant, + local_expert_masking}; fused_moe_args args{a_buf.GetDeviceBuffer(), fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, @@ -380,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser) fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr, fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr, + local_expert_masking ? 
local_expert_mask_buf.GetDeviceBuffer() + : nullptr, o_buf.GetDeviceBuffer(), topk_ids_buf.GetDeviceBuffer(), topk_weight_buf.GetDeviceBuffer(), @@ -389,7 +402,7 @@ bool run(const ck_tile::ArgParser& arg_parser) num_sorted_tiles_buf.GetDeviceBuffer(), block_m, hidden_size, - shared_intermediate_size_0, + intermediate_size / tp, tokens, experts, topk, @@ -408,39 +421,49 @@ bool run(const ck_tile::ArgParser& arg_parser) << cal_tbps(ave_time) << " TB/s" << std::flush; bool pass = true; +#define CPU_FUSED_MOE(act_type_) \ + ck_tile::reference_fused_moe(a_host, \ + g_host, \ + d_host, \ + sa_host, \ + sg_host, \ + sd_host, \ + sy_host, \ + o_host, \ + sorted_token_ids_host, \ + sorted_weight_host, \ + sorted_expert_ids_host, \ + num_sorted_tiles_host, \ + topk_ids_host, \ + block_m, \ + tokens, \ + experts, \ + hidden_size, \ + intermediate_size / tp, \ + topk, \ + gate_only) + if(do_validation) { ck_tile::reference_moe_sorting( topk_ids_host, topk_weight_host, + local_expert_mask_host, sorted_token_ids_host, sorted_weight_host, sorted_expert_ids_host, num_sorted_tiles_host.mData[0], experts, - block_m); - - ck_tile::reference_fused_moe( - a_host, - g_host, - d_host, - sa_host, - sg_host, - sd_host, - sy_host, - o_host, - sorted_token_ids_host, - sorted_weight_host, - sorted_expert_ids_host, - num_sorted_tiles_host, - topk_ids_host, block_m, - tokens, - experts, - hidden_size, - shared_intermediate_size_0, - topk, - gate_only); + local_expert_masking); + if(activation == 0) + { + CPU_FUSED_MOE(ck_tile::element_wise::Gelu); + } + else + { + CPU_FUSED_MOE(ck_tile::element_wise::Silu); + } auto o_dev = o_buf.ToHost(); // o_dev.savetxt("gpu-out.txt", "float"); @@ -457,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::reference_moe_sorting( topk_ids_host, topk_weight_host, + local_expert_mask_host, sorted_token_ids_host, sorted_weight_host, sorted_expert_ids_host, num_sorted_tiles_host.mData[0], experts, - block_m); + block_m, + local_expert_masking); // done, preparing GPU buffer ck_tile::DeviceMem a_buf(a_host); @@ -491,6 +516,7 @@ bool run(const ck_tile::ArgParser& arg_parser) prec_sq, prec_kw, block_m, + activation, gate_only, fused_quant}; @@ -507,7 +533,7 @@ bool run(const ck_tile::ArgParser& arg_parser) sorted_expert_ids_buf.GetDeviceBuffer(), num_sorted_tiles_buf.GetDeviceBuffer(), hidden_size, - shared_intermediate_size_0, + intermediate_size / tp, tokens, experts, topk, @@ -529,27 +555,14 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { - ck_tile::reference_fused_moe( - a_host, - g_host, - d_host, - sa_host, - sg_host, - sd_host, - sy_host, - o_host, - sorted_token_ids_host, - sorted_weight_host, - sorted_expert_ids_host, - num_sorted_tiles_host, - topk_ids_host, - block_m, - tokens, - experts, - hidden_size, - shared_intermediate_size_0, - topk, - gate_only); + if(activation == 0) + { + CPU_FUSED_MOE(ck_tile::element_wise::Gelu); + } + else + { + CPU_FUSED_MOE(ck_tile::element_wise::Silu); + } auto o_dev = o_buf.ToHost(); // o_dev.savetxt("gpu-out.txt", "float"); diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 9b4ed9a9e71263845121ad2f5ba4f439943ceaef..286fe4201da12a998971d26e9fba355ab9d97ae7 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -19,12 +19,9 @@ template float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - constexpr bool kTilePermute = false; - // The rank and permutation will also be generate out by the CodeGen part. - constexpr ck_tile::index_t kOutputRank = 2; + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; constexpr int kBlockPerCu = 1; @@ -41,53 +38,52 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; - // Whether doing the CShuffle (transpose before the global memory), depending on the output - // layout. - constexpr bool CShuffleEpilogue = - std::is_same_v; - using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTilePartitioner; - - using GemmEpilogue = std::conditional_t< - CShuffleEpilogue, - ck_tile::CShuffleEpilogue>, - ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>>; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; using CodegenGemmTraits = ck_tile::TileGemmTraits; - using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); constexpr dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index f0c0c9efbacdda8f2b6f6affd999b46d627093aa..7b7e22160a2e24c5c3e71ca53028e8c9c393a7e5 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -39,7 +39,7 @@ auto create_args(int argc, char* argv[]) .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("batch_stride_a", "32768", "Batch A stride") .insert("batch_stride_b", "16384", "Batch B stride") @@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index 4e7218b5b1db8ad701a21584882bc2de50b9956b..1105304e3e81ccc6fe2779553d37673f64956a0b 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -1,8 +1,28 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + template float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, @@ -17,6 +37,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::index_t batch_stride_B, ck_tile::index_t batch_stride_C, ck_tile::index_t batch_count, + ck_tile::index_t kbatch, int n_warmup, int n_repeat) { @@ -24,6 +45,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; args.M = M; args.N = N; args.K = K; @@ -79,6 +101,7 @@ int run_batched_gemm_example_with_layouts(int argc, ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b"); ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c"); ck_tile::index_t batch_count = arg_parser.get_int("batch_count"); + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); int n_warmup = arg_parser.get_int("warmup"); int n_repeat = arg_parser.get_int("repeat"); @@ -159,6 +182,7 @@ int run_batched_gemm_example_with_layouts(int argc, batch_stride_B, batch_stride_C, batch_count, + kbatch, n_warmup, n_repeat); @@ -175,10 +199,20 @@ int run_batched_gemm_example_with_layouts(int argc, ck_tile::reference_batched_gemm( a_m_k, b_n_k, c_m_n_host_ref); - - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); - - std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + + std::cout << "The CPU verification result is:" << (pass ? 
"correct" : "fail") << std::endl; } else if(arg_parser.get_int("v") == 2) { @@ -236,7 +270,18 @@ int run_batched_gemm_example_with_layouts(int argc, ck_tile::hip_check_error(hipFree(d_C)); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; } @@ -256,11 +301,11 @@ int run_batched_gemm_example(int argc, char* argv[]) std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); - if(a_layout == "R" && b_layout == "R") - { - return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } - else if(a_layout == "R" && b_layout == "C") + // if(a_layout == "R" && b_layout == "R") + // { + // return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + // } + if(a_layout == "R" && b_layout == "C") { return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 14f3b4a5b88f1e4e8e987671a40db84193b0ee1f..03d5818179ad5a98b95171927772c92b0e89987a 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include @@ -15,18 +15,14 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" -#include "utils.hpp" namespace { struct GroupedGemmKernelParam { - static const bool kPadM = false; - static const bool kPadN = false; - static const bool kPadK = false; - static const bool kTilePermute = false; - - static const ck_tile::index_t kOutputRank = 2; + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = false; static const int kBlockPerCu = 1; static const ck_tile::index_t M_Tile = 128; @@ -55,24 +51,6 @@ using CodegenGemmShape = using TilePartitioner = ck_tile::GemmTile1DPartitioner; -template -using GemmEpilogue = std::conditional_t< - std::is_same_v, - ck_tile::CShuffleEpilogue>, - ck_tile::Default2DEpilogue>>; - template using CodegenGemmTraits = ck_tile::TileGemmTraits>; -using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; - template using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1, - CodegenGemmPolicy>; + ck_tile::GemmPipelineAGmemBGmemCRegV1>; + +template +using GemmEpilogue = ck_tile::CShuffleEpilogue::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GroupedGemmKernelParam::M_Warp, + GroupedGemmKernelParam::N_Warp, + GroupedGemmKernelParam::M_Warp_Tile, + GroupedGemmKernelParam::N_Warp_Tile, + GroupedGemmKernelParam::K_Warp_Tile, + CodegenPipelineProblem::TransposeC>>; template using Kernel = ck_tile::GroupedGemmKernel, - GemmEpilogue>; + GemmEpilogue>; }; // namespace -std::size_t GetWorkspaceSize(const std::vector& gemm_descs) +std::size_t get_workspace_size(const std::vector& gemm_descs) { return ::Kernel::GetWorkSpaceSize(gemm_descs); } @@ -128,7 +118,7 @@ float grouped_gemm(const std::vector& gemm_descs, if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" + std::cout << "Launching kernel: " << GroupedGemmKernel::GetName() << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 20ba74088438a16b907fb9f3ea94bd75cb89b652..2ffef95196c79bb16620d9f45364ddc53e00fe02 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -41,7 +41,7 @@ auto create_args(int argc, char* argv[]) .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") .insert("stride_Cs", "", "Tensor C strides - it is empty by default.") .insert("a_layout", "R", "A tensor data layout - Row by default.") - .insert("b_layout", "R", "B tensor data layout - Row by default.") + .insert("b_layout", "C", "B tensor data layout - Row by default.") .insert("c_layout", "R", "C tensor data layout - Row by default.") .insert("validate", "1", "0. No validation, 1. 
Validation on CPU.") .insert("warmup", "10", "number of iterations before benchmark the kernel.") @@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -std::size_t GetWorkspaceSize(const std::vector& gemm_descs); +std::size_t get_workspace_size(const std::vector& gemm_descs); -float grouped_gemm_calc(const std::vector& gemm_descs, - const ck_tile::stream_config& s, - void* p_workspace_); +float grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* p_workspace_); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 11faa6642cc9f9aa22f767f248e83f6dcf39b80a..080ea818c92d55ede0e423431ba2d3f7bfb86537 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -1,8 +1,35 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + template float invoke_gemm(int n_warmup, int n_repeat, @@ -11,7 +38,7 @@ float invoke_gemm(int n_warmup, { ck_tile::DeviceMem gemm_workspace; - gemm_workspace.Realloc(GetWorkspaceSize(args)); + gemm_workspace.Realloc(get_workspace_size(args)); float ave_time = grouped_gemm( args, @@ -108,16 +135,16 @@ int run_grouped_gemm_example_with_layouts(int argc, const ck_tile::index_t N = Ns[i]; const ck_tile::index_t K = Ks[i]; - stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout); - stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout); - stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{}); + stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); - a_m_k_tensors.push_back( - ck_tile::HostTensor(f_host_tensor_descriptor(M, K, stride_As[i], a_layout))); - b_k_n_tensors.push_back( - ck_tile::HostTensor(f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout))); + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); c_m_n_tensors.push_back(ck_tile::HostTensor( - f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{}))); + 
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc @@ -157,14 +184,25 @@ int run_grouped_gemm_example_with_layouts(int argc, { for(int i = 0; i < group_count; ++i) { - ck_tile::HostTensor c_m_n_host_ref( - f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{})); + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); ck_tile::reference_gemm( a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); - pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value); + pass &= ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "gemm[" << i + << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; } - std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; } return pass; @@ -188,10 +226,10 @@ int run_grouped_gemm_example(int argc, char* argv[]) { return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } - else if(a_layout == "R" && b_layout == "R") - { - return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - } + // else if(a_layout == "R" && b_layout == "R") + // { + // return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + // } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); diff --git a/example/ck_tile/17_grouped_gemm/utils.hpp b/example/ck_tile/17_grouped_gemm/utils.hpp deleted file mode 100644 index bb3cdf9fdc901cfe219d1a7f35bb9640efba1b9e..0000000000000000000000000000000000000000 --- a/example/ck_tile/17_grouped_gemm/utils.hpp +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -template -constexpr auto -f_host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout) -{ - using namespace ck_tile::literals; - - if constexpr(std::is_same_v) - { - return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); - } -} -template -constexpr auto -f_get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout) -{ - if(stride == 0) - { - if constexpr(std::is_same_v) - { - return col; - } - else - { - return row; - } - } - else - return stride; -} diff --git a/example/ck_tile/35_batched_transpose/CMakeLists.txt b/example/ck_tile/35_batched_transpose/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a08fcebb74f3828bc28f5d7cab71c4bad9344711 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/CMakeLists.txt @@ -0,0 +1,9 @@ +set(TARGET_NAME tile_example_batched_transpose) +add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL batched_transpose_example.cpp batched_transpose_api.cpp) +target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + +# NOTE: we turn off undefined-func-template to let the source compile without explicitly declaring function specializations +list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) +target_compile_options(tile_example_batched_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) + diff --git a/example/ck_tile/35_batched_transpose/README.md b/example/ck_tile/35_batched_transpose/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d0583e75292c765f3eb8d81a41b72a9394a63105 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/README.md @@ -0,0 +1,27 @@ +# Batched Transpose +This folder contains an example of batched transpose using the ck_tile tile-programming implementation. Currently it supports batched transpose from NCHW to NHWC or from NHWC to NCHW, so starting from NCHW you can reach either NHWC or NWCH (the latter via two transposes). For now the transpose reads a single data point at a time; a vectorized transpose is planned. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can specify the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# Make the transpose executable +make tile_example_batched_transpose -j +``` +This will result in an executable `build/bin/tile_example_batched_transpose` + +## example +``` +args: + -N input batch size (default:2) + -C input channel size. (default:16) + -H input height size. (default:1) + -W input width size. 
(default:16) + -v whether do CPU validation or not (default: 1) + -layout_in input tensor data layout - NCHW by default + -layout_out output tensor data layout - NHWC by default + -seed seed to be used, -1 means random every time (default:-1) + -k_name t to 1 will print kernel name (default:0) +``` \ No newline at end of file diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77d768fe3fa67b6709c3f1edb05a734debdb79ab --- /dev/null +++ b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "batched_transpose_example.hpp" +#include + +template +float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) +{ + uint32_t dim_block_h = (a.height + block_y - 1) / block_y; + uint32_t dim_block_w = (a.width + block_x - 1) / block_x; + uint32_t dim_stride = a.height * a.width; + + a.dim_stride = dim_stride; + a.dim_block_h = dim_block_h; + a.dim_block_w = dim_block_w; + + using block_tile = ck_tile::sequence; + using warp_tile = ck_tile::sequence; + using thread_tile = ck_tile::sequence; + + using ts_problem = + ck_tile::BatchedTransposeProblem; + using ts_pipeline = ck_tile::BatchedTransposePipeline; + + using kernel = ck_tile::BatchedTransposeKernel; + + auto kargs = kernel::MakeKargs(a); + + const dim3 grids = kernel::GridSize(a); + constexpr dim3 blocks = kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); + + return ave_time; +} + +// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y +#define FOREACH_TRANSPOSE_PARAM(F) \ + F(fp16, ck_tile::fp16_t, 16, 16, 8, 8, 1, 1) \ + F(bf16, ck_tile::bf16_t, 16, 16, 8, 8, 1, 1) \ + F(fp32, ck_tile::fp32_t, 16, 16, 8, 8, 1, 1) \ + F(int8, ck_tile::int8_t, 16, 16, 8, 8, 1, 1) + +// Macro that defines one static function per line +#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY) \ + static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY( \ + batched_transpose_kargs& a, ck_tile::stream_config& s) \ + { \ + return batched_transpose_dispatch(a, s); \ + } + +FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN) + +float batched_transpose(batched_transpose_trait t, + batched_transpose_kargs a, + ck_tile::stream_config s) +{ + if(t.type == "fp16") + { + return transpose_fn_fp16_16_16_8_8_1_1(a, s); + } + else if(t.type == "bf16") + { + return transpose_fn_bf16_16_16_8_8_1_1(a, s); + } + else if(t.type == "fp32") + { + return transpose_fn_fp32_16_16_8_8_1_1(a, s); + } + else if(t.type == "int8") + { + return transpose_fn_int8_16_16_8_8_1_1(a, s); + } + return -1; +} diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp new file mode 100644 index 0000000000000000000000000000000000000000..48fc2859bfb4eaadcb9257c69f5fe069b9640cb9 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
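The README above describes the supported layout conversions; at its core the NCHW -> NHWC case reduces to the index remapping sketched below. This is a plain host-side illustration only — the example itself validates against ck_tile::reference_batched_transpose.
```
// Host-side sketch of the NCHW -> NHWC index remapping the example exercises.
// Illustration only; not the ck_tile reference implementation.
#include <cstddef>
#include <vector>

template <typename T>
void transpose_nchw_to_nhwc(const std::vector<T>& in, std::vector<T>& out,
                            std::size_t N, std::size_t C, std::size_t H, std::size_t W)
{
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t h = 0; h < H; ++h)
                for(std::size_t w = 0; w < W; ++w)
                    // NCHW offset: ((n*C + c)*H + h)*W + w
                    // NHWC offset: ((n*H + h)*W + w)*C + c
                    out[((n * H + h) * W + w) * C + c] = in[((n * C + c) * H + h) * W + w];
}
```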
+ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "batched_transpose_example.hpp" + +#if 0 +template +void dump_host_tensor_4d(const ck_tile::HostTensor& x) +{ + auto len = x.get_lengths(); + assert(len.size() == 4); + std::cout << "["; + for(size_t i = 0; i < len[0]; i++) + { + std::cout << i << ": ["; + for(size_t j = 0; j < len[1]; j++) + { + std::cout << j << ": ["; + for(size_t k = 0; k < len[2]; k++) + { + std::cout << k << ": ["; + for(size_t v = 0; v < len[3]; v++) + { + if constexpr(std::is_same_v) + { + auto m = + ck_tile::type_convert(x(std::vector{i, j, k, v})); + + std::cout << m; + if(v != len[3] - 1) + std::cout << ","; + } + else + { + std::cout << x(std::vector{i, j, k, v}) << " "; + } + } + std::cout << "]" << std::endl; + } + std::cout << "]" << std::endl; + } + std::cout << std::endl; + } + std::cout << "--------------------" << std::endl; +} +#endif + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "whether to do CPU validation or not") + .insert("pr", "fp16", "input data type. fp32/fp16/bf16/int8") + .insert("N", "2", "input batch size. ") + .insert("C", "16", "input channel size.") + .insert("H", "1", "input height size.") + .insert("W", "16", "input width size. ") + .insert("layout_in", "NCHW", "input tensor data layout - NCHW by default") + .insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ") + .insert("seed", "-1", "seed to be used, -1 means random every time") + .insert("kname", "0", "set to 1 to print the kernel name"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run_batched_transpose(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string prec = args.get_str("pr"); + int N = args.get_int("N"); + int C = args.get_int("C"); + int H = args.get_int("H"); + int W = args.get_int("W"); + std::string layout_in = args.get_str("layout_in"); + std::string layout_out = args.get_str("layout_out"); + int seed = args.get_int("seed"); + + int dim_in[4], dim_out[4]; + int stride_dim_in[4], stride_dim_out[4]; + bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC"; + bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW"; + assert(nchw2nhwc != nhwc2nchw); + (void)nhwc2nchw; + + dim_in[0] = N; + dim_in[1] = nchw2nhwc ? C : H; + dim_in[2] = nchw2nhwc ? H : W; + dim_in[3] = nchw2nhwc ? W : C; + dim_out[0] = N; + dim_out[1] = nchw2nhwc ? H : C; + dim_out[2] = nchw2nhwc ? W : H; + dim_out[3] = nchw2nhwc ? C : W; + stride_dim_in[0] = C * H * W; + stride_dim_in[1] = nchw2nhwc ? H * W : C * W; + stride_dim_in[2] = nchw2nhwc ?
W : C; + stride_dim_in[3] = 1; + stride_dim_out[0] = C * H * W; + stride_dim_out[1] = nchw2nhwc ? C * W : H * W; + stride_dim_out[2] = nchw2nhwc ? C : W; + stride_dim_out[3] = 1; + + if(seed < 0) + { + seed = std::time(nullptr); + } + + ck_tile::HostTensor x_host( + {dim_in[0], dim_in[1], dim_in[2], dim_in[3]}, + {stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]}); + ck_tile::HostTensor y_host( + {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, + {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + + ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes()); + + x_dev.ToDevice(x_host.data()); + + auto trait = batched_transpose_trait{prec, layout_in}; + + uint32_t height = nchw2nhwc ? C : H * W; + uint32_t width = nchw2nhwc ? H * W : C; + + batched_transpose_kargs karg = [&]() { + batched_transpose_kargs a_; + a_.p_input = x_dev.GetDeviceBuffer(); + a_.p_output = y_dev.GetDeviceBuffer(); + a_.batch = N; + a_.height = height; + a_.width = width; + return a_; + }(); + + ck_tile::stream_config sc{nullptr, true}; + + auto ms = batched_transpose(trait, karg, sc); + + std::size_t num_operations = N * C * H * (W - 1); + std::size_t num_bytes = N * C * H * W * sizeof(Type); + + float ave_time = ms * 1E-3; + float gb_per_sec = num_bytes / ms * 1.E-6; + float tflops = static_cast(num_operations) / ms * 1.E-6; + + std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H + << ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out + << " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops, " + << gb_per_sec << " GB/s" << std::endl; + + printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n", + prec.c_str(), + N, + C, + H, + W, + layout_in.c_str(), + ms); + if(ms < 0) + printf("not supported\n"); + fflush(stdout); + + if(ms < 0) + { + return false; + } + + y_dev.FromDevice(y_host.data()); + + bool rtn = true; + if(validate) + { + // this host buffer is not copied to the GPU, so there is no need to use strides + ck_tile::HostTensor y_ref( + {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, + {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); + + ck_tile::reference_batched_transpose(x_host, y_ref, layout_in, layout_out); + + auto [rtol, atol] = get_elimit(""); + + rtn &= ck_tile::check_err( + y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol); + } + printf("valid:%s\n", rtn ? "y" : "n"); + fflush(stdout); + return rtn; +} + +int main(int argc, char** argv) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return -1; + std::string prec = args.get_str("pr"); + + bool r = true; + if(prec.compare("fp32") == 0) + { + r &= run_batched_transpose(args); + } + else if(prec.compare("fp16") == 0) + { + r &= run_batched_transpose(args); + } + else if(prec.compare("bf16") == 0) + { + r &= run_batched_transpose(args); + } + else if(prec.compare("int8") == 0) + { + r &= run_batched_transpose(args); + } + + return r ?
0 : -1; +} diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp new file mode 100644 index 0000000000000000000000000000000000000000..487ddc17b227d8db8448cd8b363037598196f128 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/ops/batched_transpose.hpp" + +#include +#include + +#pragma once + +struct batched_transpose_trait +{ + std::string type; + std::string layout; +}; + +struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs +{ +}; + +float batched_transpose(batched_transpose_trait t, + batched_transpose_kargs a, + ck_tile::stream_config s); diff --git a/example/ck_tile/35_batched_transpose/script/smoke_test.sh b/example/ck_tile/35_batched_transpose/script/smoke_test.sh new file mode 100755 index 0000000000000000000000000000000000000000..fdfef2cea8f25fd644619045d95147f991f1e3f7 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/script/smoke_test.sh @@ -0,0 +1,11 @@ +#!/bin/sh + +EXE=./build/bin/tile_example_batched_transpose + +for pr in "fp32" "fp16" "int8" ; do +$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC' +done diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 296eb1ecefe8b29532856a056671305b2f5dee00..7f4ba2ed359503de2aae3227976165e41befd677 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -17,3 +17,4 @@ add_subdirectory(14_moe_smoothquant) add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) +add_subdirectory(35_batched_transpose) diff --git a/include/ck/README.md b/include/ck/README.md index bff689f6b05fe2ef777ad5f48ec5a49ed7aa3e71..92d5a510873685549f46695eeade126aa5c08186 100644 --- a/include/ck/README.md +++ b/include/ck/README.md @@ -1,19 +1,23 @@ [Back to the main page](../../README.md) # Composable Kernel supported operations ## Supported device operations -* [Average pooling]() -* [Batched contraction]() -* [Batched gemm]() -* [Batchnorm]() -* [CGEMM]() -* [Contraction]() -* [Convolution]() -* [Image to Column and Column to Image]() -* [Elementwise]() -* [GEMM]() -* [Max pooling]() -* [Reduce]() -* [Normalization]() -* [Permute]() -* [Put]() -* [Softmax]() + + + + + + + + +* [GEMM](../../client_example/01_gemm/README.md) +* [Grouped Convolution Forward](../../client_example/07_grouped_convnd_fwd/README.md) +* [Grouped Convolution Backward Data](../../client_example/10_grouped_convnd_bwd_data/README.md) +* [Grouped Convolution Backward Weight](../../client_example/11_grouped_conv_bwd_weight/README.md) + + + + + + + + diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 999eb0229c5ab16f19084a53b1b01266ed7e48c1..1ec0c6bc2338d23cbdb076d097d317cbb11ded52 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -1,11 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/config.h" #include "ck/utility/env.hpp" - +#ifndef CK_CODE_GEN_RTC #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -14,10 +14,12 @@ // environment variable to enable logging: // export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) - +#endif // to do: add various levels of logging with CK_LOG_LEVEL +#ifndef CK_TIME_KERNEL #define CK_TIME_KERNEL 1 +#endif // constant address space for kernel parameter // https://llvm.org/docs/AMDGPUUsage.html#address-spaces @@ -53,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // define general macros for various architectures #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) + defined(__gfx942__) || defined(__gfx950__) #define __gfx9__ #endif -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__) #define __gfx94__ #endif #if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) @@ -155,9 +157,22 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // LDS direct loads using inline assembly #define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0 +// set rounding to nearest even as default for bf16 conversions +#define CK_USE_RNE_BF16_CONVERSION 1 + // set rounding to nearest even as default for f8 conversions #define CK_USE_SR_F8_CONVERSION 0 +// set rounding to nearest even as default for f6 conversions +#define CK_USE_SR_F6_CONVERSION 0 + +// set rounding to nearest even as default for f4 conversions +#define CK_USE_SR_F4_CONVERSION 0 + +// shuffle pk_i4 values during conversion to optimize number of binary +// operations +#define CK_USE_PK4_LAYOUT_SHUFFLE 1 + // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 @@ -230,13 +245,18 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // workaround: compiler issue on gfx908 #define CK_WORKAROUND_SWDEV_388832 1 -// denorm test fix, required to work around dissue -#ifndef CK_WORKAROUND_DENORM_FIX -#define CK_WORKAROUND_DENORM_FIX 0 +// denorm test fix, necessary for gfx90a +#ifndef CK_GFX90A_DENORM_WORKAROUND +#define CK_GFX90A_DENORM_WORKAROUND 0 +#endif // CK_GFX90A_DENORM_WORKAROUND +// Enable only for gfx90a +#if defined(__gfx90a__) +#if CK_GFX90A_DENORM_WORKAROUND +#define CK_GFX90A_DENORM_WORKAROUND 1 +#endif // CK_GFX90A_DENORM_WORKAROUND is set to 1 #else -// enable only for gfx90a -#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) -#endif // CK_WORKAROUND_DENORM_FIX +#define CK_GFX90A_DENORM_WORKAROUND 0 +#endif // gfx90a // set flag to 1 to build deprecated instances #define CK_BUILD_DEPRECATED 1 diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 2c37300e9b6c84576b53e5d35ad5699debe112b0..994e60025d08f2997769e2c1bb85d4e0f1135dfa 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -97,6 +97,10 @@ #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ #endif +#ifndef CK_ENABLE_DPP_KERNELS +#cmakedefine CK_ENABLE_DPP_KERNELS @CK_ENABLE_DPP_KERNELS@ +#endif + // // CK kernels which support XDL (MI series) // @@ -127,6 +131,10 @@ #cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@ #endif +#ifndef CK_USE_NATIVE_MX_SUPPORT +#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@ +#endif + // 
clang-format on #endif // CK_CONFIG_H_IN diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index f5c4b43ad2464117d476fed38329b0781f77e160..05dc491af779d828cf38c7510c0f510c19ace8d5 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -55,20 +55,21 @@ inline bool is_xdl_supported() { return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || - ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; } inline bool is_lds_direct_load_supported() { // Check if direct loads from global memory to LDS are supported. return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || - ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" || + ck::get_device_name() == "gfx950"; } inline bool is_bf16_atomic_supported() { return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || - ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; } inline bool is_gfx101_supported() diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 08bfefb87f072876e5c6cdbb45db72181056dcc9..d33ecaeef8b8f6402b19e25ddeda90e3d2dee422 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -26,6 +26,7 @@ namespace utils { template double get_relative_threshold(const int number_of_accumulations = 1) { + using F4 = ck::f4_t; using F8 = ck::f8_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) using I8 = int8_t; using I32 = int32_t; - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; if constexpr(is_same_v || is_same_v || @@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) compute_error = std::pow(2, -NumericUtils::mant) * 0.5; } - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled OutDataType for setting up the relative threshold!"); double output_error = 0; if constexpr(is_same_v || is_same_v || @@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) } double midway_error = std::max(compute_error, output_error); - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled AccDataType for setting up the relative threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || @@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1) template double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) { + using F4 = 
ck::f4_t; using F8 = ck::f8_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of using I8 = int8_t; using I32 = int32_t; - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; @@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of compute_error = std::pow(2, expo - NumericUtils::mant) * 0.5; } - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled OutDataType for setting up the absolute threshold!"); double output_error = 0; if constexpr(is_same_v || is_same_v || @@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of } double midway_error = std::max(compute_error, output_error); - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled AccDataType for setting up the absolute threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || @@ -450,5 +452,54 @@ check_err(const Range& out, return res; } +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, f4_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 0.5, + double atol = 0.5) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err + << " number of errors: " << err_count << std::endl; + } + return res; +} + } // namespace utils } // namespace ck diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 2006918cd1bb3dd25e5b5321481d84a895e99c96..87a912e8886f1bd97bfd52b6231eda607c0010ee 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
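The threshold helpers extended above for f4_t derive their tolerances from the mantissa width of the types involved: roughly half a unit in the last place, scaled by the operand magnitude for the absolute case. A simplified standalone model of that calculation is sketched below; the mantissa widths used (fp32: 23, fp16: 10, bf16: 7 bits) and the example values are assumptions of the sketch, and the real helpers additionally combine compute, output and accumulation errors.
```
// Simplified model of get_relative_threshold / get_absolute_threshold:
// a half-ULP bound derived from the mantissa width of the data type.
#include <cmath>
#include <cstdio>

static double relative_threshold(int mantissa_bits)
{
    return std::pow(2.0, -mantissa_bits) * 0.5;
}

static double absolute_threshold(int mantissa_bits, double max_possible_value)
{
    const double expo = std::log2(std::abs(max_possible_value));
    return std::pow(2.0, expo - mantissa_bits) * 0.5;
}

int main()
{
    std::printf("fp16 rtol ~ %g\n", relative_threshold(10));             // ~4.88e-4
    std::printf("bf16 atol @ 256 ~ %g\n", absolute_threshold(7, 256.0)); // 1.0
}
```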
#pragma once @@ -45,10 +45,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) else os << delim; - if constexpr(std::is_same_v || std::is_same_v) + using RangeType = ck::remove_cvref_t; + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) { os << ck::type_convert(v); } + else if constexpr(std::is_same_v) + { + const auto packed_floats = ck::type_convert(v); + const ck::vector_type vector_of_floats{packed_floats}; + os << vector_of_floats.template AsType()[ck::Number<0>{}] << delim + << vector_of_floats.template AsType()[ck::Number<1>{}]; + } else { os << static_cast(v); @@ -267,18 +276,18 @@ struct Tensor using Data = std::vector; template - Tensor(std::initializer_list lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) + Tensor(std::initializer_list lens) : mDesc(lens), mData(GetElementSpaceSize()) { } template Tensor(std::initializer_list lens, std::initializer_list strides) - : mDesc(lens, strides), mData(mDesc.GetElementSpaceSize()) + : mDesc(lens, strides), mData(GetElementSpaceSize()) { } template - Tensor(const Lengths& lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize()) + Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize()) { } @@ -288,7 +297,7 @@ struct Tensor { } - Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {} + Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {} template Tensor CopyAsType() const @@ -348,7 +357,17 @@ struct Tensor std::size_t GetElementSize() const { return mDesc.GetElementSize(); } - std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); } + std::size_t GetElementSpaceSize() const + { + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return (mDesc.GetElementSpaceSize() + 1) / 2; + } + else + { + return mDesc.GetElementSpaceSize(); + } + } std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); } @@ -495,29 +514,64 @@ struct Tensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - return mDesc.GetOffsetFromMultiIndex(is...); + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mDesc.GetOffsetFromMultiIndex(is...) / 2; + } + else + { + return mDesc.GetOffsetFromMultiIndex(is...); + } } template T& operator()(Is... is) { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } } template const T& operator()(Is... is) const { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...) 
/ 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } } T& operator()(std::vector idx) { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } } const T& operator()(std::vector idx) const { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + if constexpr(ck::is_same_v, ck::pk_i4_t>) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; + } + else + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } } typename Data::iterator begin() { return mData.begin(); } diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index ab9f01b53cb422d3450530c85ca53d4c4a0af2f2..274051da83040f760ac1343babccc5b2ff6f2e76 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -69,6 +69,18 @@ struct GeneratorTensor_1 }; #endif +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f4_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + template <> struct GeneratorTensor_1 { @@ -81,6 +93,20 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + int8_t value = 1; + + template + ck::pk_i4_t operator()(Is...) + { + int t = value + 8; + ck::pk_i4_t r = ((t << 4) + t) & 0xff; + return r; + } +}; + template struct GeneratorTensor_2 { @@ -121,6 +147,22 @@ struct GeneratorTensor_2 } }; +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::pk_i4_t operator()(Is...) + { + int hi = std::rand() % (max_value - min_value) + min_value + 8; + int lo = std::rand() % (max_value - min_value) + min_value + 8; + ck::pk_i4_t r = ((hi << 4) + lo) & 0xff; + return r; + } +}; + #if defined CK_ENABLE_FP8 template <> struct GeneratorTensor_2 @@ -153,6 +195,20 @@ struct GeneratorTensor_2 }; #endif +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f4_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; + template struct GeneratorTensor_3 { @@ -223,6 +279,23 @@ struct GeneratorTensor_3 }; #endif +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f4_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; + template struct GeneratorTensor_4 { diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp index d719ef9760d79297600d7524167eba78cd137831..ef2bedd65cefadf8f68a8eefcdb282f742fab563 100644 --- a/include/ck/tensor/static_tensor.hpp +++ b/include/ck/tensor/static_tensor.hpp @@ -167,7 +167,7 @@ struct StaticTensorTupleOfVectorBuffer // Idx is for S, not X. Idx should be aligned with X template ::value && + typename enable_if<(has_same_scalar_type::value || !is_native_type()) && is_known_at_compile_time::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr X GetAsType(Idx) const @@ -201,7 +201,7 @@ struct StaticTensorTupleOfVectorBuffer // Idx is for S, not X. 
Idx should be aligned with X template ::value && + typename enable_if<(has_same_scalar_type::value || !is_native_type()) && is_known_at_compile_time::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr void SetAsType(Idx, X x) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ea0c511da37d5af8f263cb044178ea9571d1e22c --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp" + +namespace ck { + +enum struct BlockGemmPipelineVersion +{ + v1, // Naive + v2, // Mem + v3, // Comp + v4, // Comp, double lds buffer + v5, // Comp, double global prefetch register buffer +}; + +template +constexpr auto BlockGemmPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_v1_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_v2_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return BlockwiseGemmXdlops_pipeline_v4_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5) + { + return BlockwiseGemmXdlops_pipeline_v5{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4246f4a44e76b4c3f2554fcdfc90019771310f7b --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp @@ -0,0 +1,403 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
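Before the new pipeline headers, a brief illustration of the pk_i4_t packing convention that the Tensor and generator changes above rely on. This is an editorial sketch under stated assumptions, not CK API: two 4-bit values share one byte, each biased by +8 so the signed range [-8, 7] lands in an unsigned nibble, which is why element space sizes are divided by two (rounded up) and linear offsets are halved.

#include <cstddef>
#include <cstdint>

// Pack two int4 values (range [-8, 7]) into one byte, high nibble first,
// following the ((hi << 4) + lo) & 0xff layout used by the pk_i4_t
// generator specializations above.
inline std::uint8_t pack_int4_pair(int hi, int lo)
{
    return static_cast<std::uint8_t>((((hi + 8) & 0xF) << 4) | ((lo + 8) & 0xF));
}

// Storage bookkeeping matching the Tensor changes above: N packed elements
// occupy (N + 1) / 2 bytes, and element index i lives in byte i / 2.
inline std::size_t packed_int4_bytes(std::size_t n) { return (n + 1) / 2; }
inline std::size_t packed_int4_offset(std::size_t i) { return i / 2; }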
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v1_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop, + index_t num_loop_per_scale) const + { + // assume kperblock = scaleblockk + ignore = num_loop_per_scale; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + 
b_scale_thread_copy_step.At(Number<1>{})); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + auto c_thread_buf_per_scale = remove_cvref_t(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(b_scale_thread_buf[n0]); + }); + }); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + }); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + 
a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(b_scale_thread_buf[n0]); + }); + }); + }); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..776f66dbbb8f84c61a5882b92a0f8099db76d1d0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp @@ -0,0 +1,1248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Maximum Global Memory throughput pipeline with >=32KB data in fly +// GlobalPrefetchStages: >=2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v2_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t WgpPerCU = + (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? 
FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + 
b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + if 
constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS; + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + + static constexpr index_t WgpPerCU = + (4 * warpSize / BlockSize) >= 1 ? 
4 * warpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + const BScaleGridDesc& b_scale_grid_desc, + // BScaleThreadCopy + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num loop + index_t num_loop, + index_t num_loop_per_scale) const + { + ignore = num_loop_per_scale; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + 
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + auto c_thread_buf_per_scale = remove_cvref_t(); // need? + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC + // cluster, but except the first, as we can shorten non-MAC cluster a bit + // and there's no observable negative impact. The desired effect is waves in + // a workgroup executing MAC in sync. This avoids some out-of-sync waves + // hijacking MAC resource from other workgroups and reducing the chance of + // latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) + // { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // b_scale_thread_copy.Run(b_scale_grid_desc, + // b_scale_grid_buf, + // b_scale_thread_desc, + // make_tuple(n0, I0), + // b_scale_thread_buf); + + // 
b_scale_thread_copy.MoveSrcSliceWindow( + // b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + // }); + // b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + // b_scale_thread_copy_step.At(Number<1>{})); + + // block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // b_scale_thread_copy.Run(b_scale_grid_desc, + // b_scale_grid_buf, + // b_scale_thread_desc, + // make_tuple(n0, I0), + // b_scale_thread_buf); + + // b_scale_thread_copy.MoveSrcSliceWindow( + // b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + // }); + // b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + // b_scale_thread_copy_step.At(Number<1>{})); + + 
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if 
constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + // K->M loopover + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d1be88dd632fa3d878cdd6479d8e7445ceb8217f --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp @@ -0,0 +1,530 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
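One illustrative calculation for the prefetch-depth heuristic both v2 specializations above use (host-side sketch; the tile sizes in the example are assumptions, not values fixed by this patch). The idea is to keep roughly 32 KiB of global loads in flight per CU, divided among the workgroups resident on it, clamped to between 2 and 8 stages.

#include <cstddef>

constexpr std::size_t integer_divide_ceil(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

// bytes_a = MPerBlock * sizeof(ADataType), bytes_b = NPerBlock * sizeof(BDataType)
constexpr int prefetch_stages(std::size_t bytes_a, std::size_t bytes_b,
                              std::size_t k_per_block, std::size_t wgp_per_cu)
{
    const std::size_t stages =
        integer_divide_ceil(32768 / wgp_per_cu, (bytes_a + bytes_b) * k_per_block);
    return stages < 2 ? 2 : (stages > 8 ? 8 : static_cast<int>(stages));
}

// Example: a 256x128 fp16 tile with KPerBlock = 32 and one workgroup per CU
// needs ceil(32768 / 24576) = 2 stages, i.e. PrefetchStages == 2.
static_assert(prefetch_stages(256 * 2, 128 * 2, 32, 1) == 2, "");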
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v3_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? 
+ // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num loop + index_t num_loop, + index_t num_loop_per_scale) const + { + __builtin_amdgcn_sched_barrier(0); + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // B scale buffer + auto 
b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + constexpr auto num_scale_k_block = BScaleThreadDesc{}.GetLength(I1); + constexpr auto num_scale_krepeat = KRepeat / num_scale_k_block; + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_scale_thread_buf[Number{}], + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + }); + + if((i + 2) % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{})); + } + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename 
vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_scale_thread_buf[Number{}], + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f35c7a97cc323e438ded5e120ed1b5c39a3d3474 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp @@ -0,0 +1,686 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
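For reference, the DS-read/MFMA pairing rate computed by HotLoopScheduler() above works out as follows; the concrete values are an assumed configuration (NPerXDL == 32 with 16-byte LDS reads), shown only to make the integer arithmetic concrete.

// Each scheduling slot pairs ds_read_mfma_rate DS reads (mask 0x100) with one
// MFMA (mask 0x008); VMEM reads use mask 0x020 and DS writes mask 0x200, as
// annotated in the scheduler code above.
constexpr int mfma_cycle          = 32; // NPerXDL == 32
constexpr int ds_read_issue_cycle = 8;  // 16-byte LDS read
constexpr int ds_read_mfma_rate =
    (mfma_cycle - 4 + 2 * ds_read_issue_cycle - 1) / (2 * ds_read_issue_cycle);
static_assert(ds_read_mfma_rate == 2, "two ds_read instructions are hidden behind each MFMA");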
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimimal pipeline with highest resource request +// GlobalPrefetchStages: 4 +// LocalPreFillStages: 2 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 + +template +struct BlockwiseGemmXdlops_pipeline_v4_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v4_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 2; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopUnroll = 2; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + __device__ static constexpr void HotLoopScheduler() + { + // TODO: Take data type into consideration as pipe ver 3 + // A-B splited schedule + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? 
HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_a = + (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a; + constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a; + + constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_b = + (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b; + constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b; + + constexpr auto num_mfma_per_issue = + HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b); + + static_for<0, num_issue_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_a, + 0); // MFMA + }); + + static_for<0, num_issue_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_b, + 0); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num loop + index_t num_loop, + index_t num_loop_per_scale) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // B scale buffer + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, 
b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0)); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(2 % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_scale_thread_bufs(I0)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); + }); + }); + }); + + // Local prefill 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1)); + + // Global prefetch 3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(3 % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + // This hot loop has two legacy loopover, to implement the double local buffer strategy + do + { + auto LoopFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + 
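// Editor's note (descriptive only, not part of this patch): the four indices passed to this
// lambda select which half of each double buffer the unrolled iteration touches:
//   lds_read_buf     : LDS tile that is read into registers
//   lds_read_reg_buf : register tile being refilled by those LDS reads
//   lds_write_buf    : LDS tile overwritten by the incoming global prefetch
//   mfma_reg_buf     : register tile currently fed to the MFMAs
// The LoopFunc(I1, I1, I0, I0) / LoopFunc(I0, I0, I1, I1) pair below swaps these roles each pass.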
auto lds_write_buf, + auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_scale_thread_bufs(lds_read_buf)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + // B scale copy + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(lds_read_reg_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + }); + + if((i + 4 + mfma_reg_buf.value) % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{})); + } + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + LoopFunc(I1, I1, I0, I0); + LoopFunc(I0, I0, I1, I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto lds_write_buf, + auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_scale_thread_bufs(lds_read_buf)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 
1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_scale_thread_bufs(lds_read_buf)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + auto CompFunc = [&](auto mfma_reg_buf) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + // tail + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I1, I1, I0, I0); + ReadCompFunc(I0, I0, I1); + CompFunc(I0); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadCompFunc(I1, I1, I0); + CompFunc(I1); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp index 1c4de5ed3153f2cc8ca7f4a8ccf57741085fba40..0a0bcbac38c2a732f0d9c2f342797e19c0400845 100644 --- 
a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -131,7 +131,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2 } template - using is_tuple = decltype(std::declval().IsTuple()); + using is_tuple = decltype(ck::declval().IsTuple()); template __device__ void RunWrite(const DstDescs& dst_descs, diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index 0eef827a5b5018efeee94d11f1b768c739d85648..cf20025d46e1ac0ba0de1c529b3fea5d90621915 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,9 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#ifndef CK_CODE_GEN_RTC #include +#endif namespace ck { namespace tensor_operation { @@ -18,6 +20,7 @@ enum struct ConvolutionForwardSpecialization Filter3x3, }; +#ifndef CK_CODE_GEN_RTC inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s) { switch(s) @@ -30,6 +33,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp default: return "Unrecognized specialization!"; } } +#endif } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 736e241fdfe4538508403dadd61c5bb1f147c3bd..774982d905fb49551654b574787a5b0fc2bbde81 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -1,19 +1,21 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
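// Editor's note (illustrative sketch, not part of this patch): the CK_CODE_GEN_RTC guards added
// in this change compile out host-only pieces (standard-library includes, std::string helpers,
// BaseArgument/BaseInvoker) when a header is consumed by the runtime-compiled codegen path,
// where the C++ standard library is not available. The guard pattern in isolation:
#ifndef CK_CODE_GEN_RTC
#include <string>
// host-only helper; hidden from RTC-style builds, which only need the device-side declarations
inline std::string host_only_note() { return "compiled for the host path only"; }
#endif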
#pragma once +#ifndef CK_CODE_GEN_RTC #include #include #include #include - #include "ck/stream_config.hpp" +#endif namespace ck { namespace tensor_operation { namespace device { +#ifndef CK_CODE_GEN_RTC #define GET_OBJECT_NAME_IMLP \ std::optional GetObjectName() const override \ { \ @@ -41,7 +43,9 @@ namespace device { } #define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL +#endif +#ifndef CK_CODE_GEN_RTC struct BaseArgument { BaseArgument() = default; @@ -66,13 +70,14 @@ struct BaseInvoker virtual ~BaseInvoker() {} }; +#endif struct BaseOperator { BaseOperator() = default; BaseOperator(const BaseOperator&) = default; BaseOperator& operator=(const BaseOperator&) = default; - +#ifndef CK_CODE_GEN_RTC virtual bool IsSupportedArgument(const BaseArgument*) { return false; } virtual std::string GetTypeString() const { return ""; } @@ -100,7 +105,7 @@ struct BaseOperator assert(p_arg); p_arg->p_workspace_ = p_workspace; } - +#endif virtual ~BaseOperator() {} }; diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp index 6cc2c7bb2f6c176f2d84fdda4be2140db5564360..fcb46082933771ce7f2a24485d30ea54e75f0680 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp @@ -44,6 +44,48 @@ struct DeviceBatchedGemm : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; }; +template +struct DeviceBatchedGemmV2BScale : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t StrideScaleB, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB, + ck::index_t BatchStrideC, + ck::index_t BatchStrideScaleB, + const void* p_b_scale, + ck::index_t Batch, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual bool GetPermuteB() = 0; + virtual ck::index_t GetKPerBlock() = 0; +}; + template MakeInvokerPointer() = 0; + + virtual bool GetPermuteA() = 0; + virtual bool GetPermuteB() = 0; + virtual ck::index_t GetKPerBlock() = 0; }; template MakeInvokerPointer() = 0; }; +template +struct DeviceGemmV2BScale : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t StrideScaleB, + const void* p_b_scale, + ck::index_t KSplit, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual bool GetPermuteB() = 0; + virtual ck::index_t GetKPerBlock() = 0; +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp index 184efbbd68ecea5b3c2b36f52764690e3ad316da..8c9b768a8b9c75ad5da9859c72d53b3d98a5967e 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp +++ 
b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp @@ -1,9 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#ifndef CK_CODE_GEN_RTC #include +#endif #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" @@ -13,8 +15,13 @@ namespace ck { namespace tensor_operation { namespace device { +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else template using is_tuple = decltype(std::declval().IsTuple()); +#endif /** * \brief Grouped Convolution Forward @@ -72,12 +79,18 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator static constexpr index_t NumDTensor = DsDataType::Size(); static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); - +#ifdef CK_CODE_GEN_RTC + using APointers = ck::conditional_t&, const void*>; + using BPointers = ck::conditional_t&, const void*>; +#else // If DataType is tuple, user has to pass std::array with pointers. using APointers = - std::conditional_t&, const void*>; + ck::conditional_t&, const void*>; using BPointers = - std::conditional_t&, const void*>; + ck::conditional_t&, const void*>; +#endif + +#ifndef CK_CODE_GEN_RTC /** * \brief Make argument pointer for grouped conv fwd. @@ -150,6 +163,7 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator const CDEElementwiseOperation& cde_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp index 0bb45b18c3e19b2ec5f9347c1e811d8734ee45a9..997dcb75a6faee60912f23d55f77c54a60cd2ff4 100644 --- a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
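// Editor's note (illustrative sketch, not part of this patch): how a caller might drive the
// DeviceGemmV2BScale interface declared earlier in this change, given a concrete instance and
// its element-wise operators. The sizes, strides, and KSplit below are made-up example values;
// StrideScaleB and the b_scale layout depend on the instance's ScaleBlockN/ScaleBlockK setup.
#include <memory>
#include "ck/stream_config.hpp"

template <typename DeviceOp, typename AOp, typename BOp, typename COp>
float run_b_scale_gemm_sketch(DeviceOp& op,
                              const void* p_a,
                              const void* p_b,
                              const void* p_b_scale,
                              void* p_c,
                              AOp a_op,
                              BOp b_op,
                              COp c_op)
{
    auto arg = op.MakeArgumentPointer(p_a, p_b, p_c,
                                      /*M=*/4096, /*N=*/4096, /*K=*/4096,
                                      /*StrideA=*/4096, /*StrideB=*/4096, /*StrideC=*/4096,
                                      /*StrideScaleB=*/64, // example value, instance dependent
                                      p_b_scale,
                                      /*KSplit=*/1,
                                      a_op, b_op, c_op);

    if(!op.IsSupportedArgument(arg.get()))
        return -1.f; // this instance cannot handle the requested problem

    auto invoker = op.MakeInvokerPointer();
    // returns the measured kernel time when timing is enabled in the StreamConfig
    return invoker->Run(arg.get(), StreamConfig{});
}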
#pragma once @@ -29,6 +29,7 @@ enum struct GemmSpecialization MNKOPadding, }; +#ifndef CK_CODE_GEN_RTC inline std::string getGemmSpecializationString(const GemmSpecialization& s) { switch(s) @@ -52,6 +53,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s) default: return "Unrecognized specialization!"; } } +#endif } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 180e32c8b6b8ee41577b3f9700614990a102b2ee..00518b369f4193327806d062e6cd9662ebb6489a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -3,11 +3,17 @@ #pragma once +#ifndef CK_CODE_GEN_RTC #include #include #include #include #include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#endif #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -15,15 +21,12 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/io.hpp" namespace ck { namespace tensor_operation { @@ -91,8 +94,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -259,8 +261,13 @@ __global__ void } // namespace +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else template using is_tuple = decltype(std::declval().IsTuple()); +#endif // // @brief Device Convolution operation. 
@@ -429,8 +436,8 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // If we are using multiAB and one of the template datatype parameters is not a tuple, convert // it to it - using GemmADataType = std::conditional_t, ADataType>; - using GemmBDataType = std::conditional_t, BDataType>; + using GemmADataType = ck::conditional_t, ADataType>; + using GemmBDataType = ck::conditional_t, BDataType>; #define GridwiseGemmTemplateParameters \ GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ @@ -449,15 +456,13 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched // Use appropriate gridwise gemm using GridwiseGemm = - std::conditional_t, - GridwiseGemmMultipleD_xdl_cshuffle>; + ck::conditional_t, + GridwiseGemmMultipleD_xdl_cshuffle>; // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. - using APointers = - std::conditional_t&, const void*>; - using BPointers = - std::conditional_t&, const void*>; + using APointers = ck::conditional_t&, const void*>; + using BPointers = ck::conditional_t&, const void*>; // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not // in initializer list what is required for single const pointer). using AGridPointer = remove_cvref_t< @@ -812,7 +817,6 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static_for<0, NumDTensor, 1>{}([&](auto i) { using DLayout = remove_cvref_t>; - // FIXME: layout if constexpr(is_same_v || is_same_v || is_same_v || is_same_v || @@ -965,18 +969,18 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle const BElementwiseOperation& b_element_op, const CDEElementwiseOperation& cde_element_op) { - std::array a_g_n_c_wis_lengths_i32; - std::array a_g_n_c_wis_strides_i32; - std::array b_g_k_c_xs_lengths_i32; - std::array b_g_k_c_xs_strides_i32; - std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; - std::array, NumDTensor> ds_g_n_k_wos_strides_i32; - std::array e_g_n_k_wos_lengths_i32; - std::array e_g_n_k_wos_strides_i32; - std::array conv_filter_strides_i32; - std::array conv_filter_dilations_i32; - std::array input_left_pads_i32; - std::array input_right_pads_i32; + ck::Array a_g_n_c_wis_lengths_i32; + ck::Array a_g_n_c_wis_strides_i32; + ck::Array b_g_k_c_xs_lengths_i32; + ck::Array b_g_k_c_xs_strides_i32; + ck::Array, NumDTensor> ds_g_n_k_wos_lengths_i32; + ck::Array, NumDTensor> ds_g_n_k_wos_strides_i32; + ck::Array e_g_n_k_wos_lengths_i32; + ck::Array e_g_n_k_wos_strides_i32; + ck::Array conv_filter_strides_i32; + ck::Array conv_filter_dilations_i32; + ck::Array input_left_pads_i32; + ck::Array input_right_pads_i32; array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index 64aa398d531e634c54c3d55d591789f72943bd0e..d53fbca4eaf4b62ec419a37071881ee192130b48 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -56,8 +56,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || 
defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index d06eab1264684b595dc87dc1cf91a9f2b44a5056..25a9d7f96dea69a339b977c2cc7a9e6a973317c7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -74,8 +74,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index e950169ccfe58887351c4bfbc82af85a061287cf..985752796bfa5f5a62bf88c881840af0fa4e95ee 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -108,7 +107,7 @@ __global__ void ignore = block_2_ctile_map; ignore = batch_count; ignore = compute_base_ptr_of_batch; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index d6b92bc97a8d3eed0d0ce50ad26a9169c498671a..630f143260495a9aeb11f0c764a03ef21b348756 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -83,8 +83,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 
6ab1669e30f6962a9d88845ea9a89806ad844f56..f6c228fb7b448e12c07eae08b7879f4e94e1c398 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -68,8 +68,7 @@ __global__ void const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index 34b1d503afe78e324d43f5fb7df6531809756e99..30ae72a63e98dfb8f86b7ccf45f620b4ba3633a8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -59,8 +59,7 @@ __global__ void const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index e178b8f5252781ead149f5d2b78f8fc53125a3af..2662e5c360b2d8e16082e62e84b9779ef913b00d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -67,8 +67,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -127,7 +126,7 @@ __global__ void ignore = batch_count; ignore = compute_base_ptr_of_batch; ignore = c0_matrix_mask; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 9af1a447814d530a06b5de7083ad691c03e6962e..ea5a5d0e16962376fb6aa7eba48c67ab1a5f6c70 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -62,8 +62,7 @@ __global__ void 
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -112,7 +111,7 @@ __global__ void ignore = batch_count; ignore = compute_base_ptr_of_batch; ignore = c0_matrix_mask; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 @@ -611,6 +610,96 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return true; } + static constexpr bool + IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_) + { + // check vector load/store + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v) + { + if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B + if constexpr(is_same_v) + { + if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B1 + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of C + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + + return true; + } + static bool IsSupportedArgument(const Argument& arg) { if(!ck::is_xdl_supported()) @@ -625,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; - // Check scalar per vector requirement - const auto a_extent_lowest = - is_same_v ? KRaw : MRaw; - const auto b_extent_lowest = - is_same_v ? NRaw : KRaw; - const auto b1_extent_lowest = - is_same_v ? Gemm1NRaw : NRaw; - const auto c_extent_lowest = - is_same_v ? 
Gemm1NRaw : MRaw; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - return false; - } - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + arg.block_2_ctile_map_) and + IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); } // polymorphic @@ -765,6 +837,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return str.str(); } + + template + struct Descriptor + { + template + static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc) + { + const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc) + { + const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc) + { + const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc); + + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc) + { + return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc); + } + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using B1GridDesc_BK0_N_BK1 = + remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + 
ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + matrix_padder.PadN, + MaskOutUpperTriangle>; + + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1; + CGridDesc_M_N c_grid_desc_m_n; + C0MatrixMask c0_matrix_mask; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_descriptor_mblock_mperblock_nblock_nperblock; + + // element-wise op + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + B1ElementwiseOperation b1_element_op; + CElementwiseOperation c_element_op; + + bool has_main_k_block_loop = true; + bool is_valid = false; + + constexpr Descriptor(ADesc a, + BDesc b, + B1Desc b1, + CDesc c, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + B1ElementwiseOperation b1_element_op_, + CElementwiseOperation c_element_op_) + : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)}, + b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)}, + b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)}, + c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)}, + block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)}, + c_grid_descriptor_mblock_mperblock_nblock_nperblock{ + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n)}, + has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( + a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, + c0_matrix_mask{c.GetLength(I1)}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + b1_element_op{b1_element_op_}, + c_element_op{c_element_op_}, + is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_m_n, + block_2_ctile_map) and + IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), + b_grid_desc_bk0_n_bk1.GetLength(I1), + a_grid_desc_ak0_m_ak1.GetLength(I0) * + a_grid_desc_ak0_m_ak1.GetLength(I2), + b1_grid_desc_bk0_n_bk1.GetLength(I1))} + { + } + + constexpr bool IsValid() const { return is_valid; } + }; + + template + static constexpr auto + make_descriptor(ADesc a, + BDesc b, + B1Desc b1, + CDesc c, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{}, + CElementwiseOperation c_element_op = CElementwiseOperation{}) + { + return Descriptor( + a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op); + } + + template + __device__ static void Run(const Desc& desc, + const float scale, + const ADataType* __restrict__ p_a_grid, + const 
ADataType* __restrict__ p_b_grid, + const ADataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid) + { +#ifndef __HIPCC_RTC__ + assert(desc.is_valid); +#endif + __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()]; + AccElementwiseOperation acc_element_op{scale}; + + if(desc.has_main_k_block_loop) + { + Desc::GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + acc_element_op, + desc.b1_element_op, + desc.c_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.b1_grid_desc_bk0_n_bk1, + desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, + desc.block_2_ctile_map, + desc.c0_matrix_mask); + } + else + { + Desc::GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + acc_element_op, + desc.b1_element_op, + desc.c_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.b1_grid_desc_bk0_n_bk1, + desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, + desc.block_2_ctile_map, + desc.c0_matrix_mask); + } + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index 6be2ffbdd781d27cc43845b7da311bc8a929e42b..494524b6f0588c33687ffd4a434977a70a173a10 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -52,8 +52,7 @@ __global__ void #endif kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..963f0edd08813251f2ecf06e2bc4e847f06827e2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp @@ -0,0 +1,1007 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; + + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto b_scale_batch_offset = karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; + + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto b_scale_batch_offset = karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale + : public DeviceBatchedGemmV2BScale +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + 
NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideC_(BatchStrideC), + BatchStrideScaleB_(BatchStrideScaleB) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_) / BPackedSize; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + __host__ __device__ constexpr long_index_t GetSacleBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideScaleB_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + index_t BatchStrideScaleB_; + }; + + struct Argument : public GridwiseGemm::Argument + { + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t BatchStrideScaleB_, + const BScaleDataType* p_b_scale_grid_, + index_t Batch_, + index_t KBatch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + StrideScaleB_, + p_b_scale_grid_, + KBatch_, // KBatch + a_element_op_, + b_element_op_, + c_element_op_), + Batch(Batch_), + compute_ptr_offset_of_batch( + BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_) + { + } + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 
2 + : 1 + : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + 
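// Single-split path (KBatch == 1): C is written directly, so the p_c_grid
// memset above is skipped and no atomics are needed; the same tail-number
// dispatch is repeated below with InMemoryDataOperationEnum::Set as the
// C write mode.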
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + 
true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB, + const BScaleDataType* p_b_scale, + index_t Batch, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + 
StrideB, + StrideC, + StrideScaleB, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchStrideScaleB, + p_b_scale, + Batch, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB, + const void* p_b_scale, + index_t Batch, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchStrideScaleB, + static_cast(p_b_scale), + Batch, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<(p_as_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 6e6921351356740f899daab6b6343a46008072fc..8aa20f7ad476e0292a482371d5a2dd27ed5ae35f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -55,8 +55,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -97,7 +96,7 @@ __global__ void ignore = b_element_op; ignore = c_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp index 811f1ae9396de6f9166c4d6c78dcc9c757e42996..b9467ac1945cad037d30083e09d124ceac50bfbd 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp @@ -50,9 +50,8 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ - defined(__gfx12__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \ + defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp index eaafd7d5c5eb05b69447ba1b6157da1cab25e230..47fb630ea9037432b506cb30ce5699b645e22365 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp @@ -63,8 +63,7 @@ __global__ void const Block2ETileMap block_2_etile_map, index_t NRaw) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()]; GridwiseGemmWelford::template Run( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index bb2db930c8ea7b1ed34b17a7cde56a2c7c6daafc..c048e7249c305fba0d2cfcfab1a0e78199e960f5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 77ed9625c5d6429967ff9636ce9829a28c3b7345..e6466a487b11826929faf2454796bb457b2710ef 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -52,8 +52,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp old mode 100755 new mode 100644 index cfd9a12047afb98233cbec35c4aca55b5784db8a..26be5cfc613a1fc9a18341709b4cc17e532b19ad --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp @@ -469,7 +469,11 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 && + arg.Streamk_sel > 0) + { + return false; + } if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 5a90c09d1dd9c8c8648a048ebd4bcb702eaa68c6..a8cf681995e5947ded2714960a171085ba4c952a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -64,7 +64,9 @@ template + typename ComputeTypeB = ComputeTypeA, + bool PermuteA = false, + bool PermuteB = false> struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2; + ComputeTypeB, + PermuteA, + PermuteB>; using Argument = typename GridwiseGemm::Argument; @@ -134,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 0) { arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); } if(!GridwiseGemm::CheckValidity(arg)) @@ -645,6 +650,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2(p_arg)); } + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + static auto MakeArgument(const ADataType* p_a, const BDataType* p_b, CDataType* p_c, @@ -736,7 +746,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + 
BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 
2 + : 1 + : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + 
TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return 
GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const BScaleDataType* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + p_b_scale, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const void* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index cc022b89c5ed472758300a1cd4e19cc5b2924f09..1cf58fec258123dee7957534061e14f0f056479a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -37,8 +37,7 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index c8c58d5d8589da6236d3da73a0f5f303fc540d7e..99bd3be15df3688c2b7688d2acde6c74bbf03fff 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -87,8 +87,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffsetOfN compute_ptr_offset_of_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index a7df1c9d570c81af44f226709d1334f7f254810c..57c4b1a5cf46c411084fcbe74b5b84c8f89f4ace 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -103,7 +102,7 @@ __global__ void compute_ptr_offset_of_batch.GetAPtrOffset(0); compute_ptr_offset_of_batch.GetBPtrOffset(0); compute_ptr_offset_of_batch.GetCPtrOffset(0); -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template 1)) { return false; } - if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1)) + if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1 && + NumGroupsToMerge > 1)) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 996107343d510f6b0c5e418881fa90ea5b60e4b2..ef87bb52ae37819484fe22759f59ae30425bac27 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
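A recurring change in the hunks above collapses the per-target guard list (gfx908, gfx90a, gfx94x) into the single __gfx9__ family macro. A minimal sketch of the resulting guard pattern, assuming __gfx9__ is defined for every gfx9xx compilation target (which is what makes the collapsed form equivalent to the old list):

    #include <hip/hip_runtime.h>

    // Illustrative stub only, not a real CK kernel: the host pass (no
    // __HIP_DEVICE_COMPILE__) always sees the body so launches type-check,
    // while the device pass keeps it only when compiling for a gfx9 target.
    __global__ void gfx9_only_kernel(float* out)
    {
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
        out[0] = 1.0f; // stand-in for the real gridwise GEMM / conv body
    #else
        (void)out; // other device targets compile an empty body
    #endif
    }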
#pragma once @@ -584,6 +584,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle { return false; } + if(!is_bf16_atomic_supported() && std::is_same_v) + { + return false; + } if constexpr(NDimSpatial == 1) { if constexpr(!is_GNWC_GKXC_GNWK()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index f21a45938f5079a4093c6dfddf41e3a7390d5e45..02ca8f42e496ccc69f1dd5a5e34687c47e507d7a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,6 +9,7 @@ #include #include +#include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -98,8 +99,7 @@ __global__ void const ComputePtrOffsetOfG compute_ptr_offset_of_groups, const ComputePtrOffsetOfN compute_ptr_offset_of_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); @@ -121,19 +121,6 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); - if constexpr(is_same_v) - { - a_element_op.InitUnaryOpPtrOnDevice(); - } - if constexpr(is_same_v) - { - b_element_op.InitUnaryOpPtrOnDevice(); - } - if constexpr(is_same_v) - { - cde_element_op.InitUnaryOpPtrOnDevice(); - } - if constexpr(isMultiA || isMultiB) { AsPointer p_as_grid_grp; @@ -225,9 +212,13 @@ __global__ void } } // namespace - +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else template using is_tuple = decltype(std::declval().IsTuple()); +#endif // // @brief Device Convolution operation. 
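The is_tuple alias introduced above is a detection probe: decltype(declval<T>().IsTuple()) is well-formed only for CK tuple types, and the CK_CODE_GEN_RTC branch swaps in ck::declval, presumably because the standard <utility> header is unavailable under runtime compilation. A self-contained sketch of how such a probe is typically consumed, using a local is_detected helper written here for illustration (CK's own utility may differ in spelling):

    #include <type_traits>
    #include <utility>

    // Local detection-idiom helper (illustrative, not CK's exact utility).
    template <typename T, template <typename> class Op, typename = void>
    struct is_detected : std::false_type {};

    template <typename T, template <typename> class Op>
    struct is_detected<T, Op, std::void_t<Op<T>>> : std::true_type {};

    // Probe mirroring the alias above: well-formed only if T has IsTuple().
    template <typename T>
    using is_tuple = decltype(std::declval<T>().IsTuple());

    struct PlainType {};
    struct TupleLike { constexpr bool IsTuple() const { return true; } };

    static_assert(!is_detected<PlainType, is_tuple>::value, "not tuple-like");
    static_assert(is_detected<TupleLike, is_tuple>::value, "tuple-like");

This kind of probe is what backs the isMultiA / isMultiB distinction visible in the surrounding kernel code, selecting the tuple-of-pointers path only when the data type really is a CK Tuple.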
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 589a0daa99d4070ebf72d64961e239cfdc7c5488..9363d7ecb9a0f411412780db91a46c19e7bbbbfa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -9,6 +9,7 @@ #include #include +#include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -117,7 +118,7 @@ __global__ void c_grid_desc_mblock_mperblock_nblock_nperblock); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template ::value, bool>::type = false> @@ -444,6 +445,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&) os << Layout::name; return os; } +#endif } // namespace tensor_layout } // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index c87c90a91dd9f731713553e0d56bd24a9a0b25d2..530876650ee5a676cb7072a8d0f110708480f048 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -340,8 +340,8 @@ struct Bilinear }; template <> - __host__ __device__ constexpr void operator()( - std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x0, const int8_t& x1) const { y = type_convert(alpha_ * type_convert(x0) + beta_ * type_convert(x1)); diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 42ec69a4b69d1b544a1b9a8b3a7e2011a40b49dd..b57ae22172ca4b902710641c28f30877c96044eb 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
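For reference, the Bilinear specialization touched above computes y = alpha * x0 + beta * x1 with the int32_t accumulator and int8_t residual both promoted to float, and only then narrows back to int8_t. A host-side sketch of that arithmetic, with a hypothetical to_int8 helper standing in for CK's type_convert (the real rounding and saturation behaviour is whatever type_convert defines):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Hypothetical narrowing helper; CK's type_convert<int8_t>(float) is the
    // authority on rounding/saturation, this just picks a reasonable choice.
    inline std::int8_t to_int8(float v)
    {
        v = std::nearbyint(v);
        v = std::min(127.0f, std::max(-128.0f, v));
        return static_cast<std::int8_t>(v);
    }

    // y = alpha * x0 + beta * x1, accumulated in float as in the operator above.
    inline std::int8_t bilinear_i8(float alpha, float beta, std::int32_t x0, std::int8_t x1)
    {
        return to_int8(alpha * static_cast<float>(x0) + beta * static_cast<float>(x1));
    }

Accumulating in float before the final narrow avoids intermediate integer overflow and applies the fractional alpha/beta scaling before the single rounding step.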
#pragma once @@ -543,7 +543,7 @@ struct NormalizeInInfer const T3& gamma, const T4& beta) const { - static_assert(std::is_same::value || std::is_same::value, + static_assert(is_same::value || is_same::value, "Data type is not supported by this operation!"); using ck::type_convert; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index dad4c2c771f829f5b543d03d15aecdcde0ceabb9..be4e68bfface6aa8e0d66f32515500ff551f5c9c 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,36 +7,203 @@ #include "ck/utility/math.hpp" #include "ck/utility/math_v2.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/utility/amd_inline_asm.hpp" #include namespace ck { + +// Fast int4x4 to half8_t data type conversion based on paper +// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production] +// (https://arxiv.org/abs/2211.10017) and implementation: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +// Convert lower part of packed int4 -> int4 to half +__device__ inline half4_t i4_to_half4(int q) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + // Extract the two int4 at low bit and create two fp16 number. + int lo = amd_assembly_and_or_b32(q, LO, EX); + // Extract the two int4 at hight bit and create two fp16 number. + int hi = amd_assembly_and_or_b32(q, HI, EX); + + const int SUB = 0xE408E408; // half2 {-1032, -1032} + const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} + const int ADD = 0xd480d480; // half2 {-72, -72} + + vector_type res; + + // for two fp16 from lowbit, subtract 1032 to get correct fp16 value + res.template AsType()(Number<0>{}) = + amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); + + // for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value + res.template AsType()(Number<1>{}) = amd_assembly_pk_fma_f16( + bit_cast(hi), bit_cast(MUL), bit_cast(ADD)); + + return res.template AsType()[Number<0>{}]; +} + +__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + // Extract the two int4 at low bit and create two fp16 number. + int lo = amd_assembly_and_or_b32(q, LO, EX); + // Extract the two int4 at hight bit and create two fp16 number. 
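// Why these magic constants work (same trick as i4_to_half4 above): OR-ing a
// nibble into the mantissa of 0x6400 (fp16 1024) produces exactly 1024 + n for
// the low nibble and 1024 + 16*n for the high nibble, so (1024 + n) - 1032 and
// (1024 + 16*n) * (1/16) - 72 both evaluate to n - 8, yielding the fp16 lanes
// without any integer-to-float conversion instructions.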
+ int hi = amd_assembly_and_or_b32(q, HI, EX); + + const int SUB = 0xE408E408; // half2 {-1032, -1032} + const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} + const int ADD = 0xd480d480; // half2 {-72, -72} + + vector_type res; + + res.template AsType()(Number<0>{}) = + amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); + + res.template AsType()(Number<1>{}) = amd_assembly_pk_fma_f16( + bit_cast(hi), bit_cast(MUL), bit_cast(ADD)); + + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(res.template AsType()(Number<0>{})) + : "v"(res.template AsType()(Number<0>{})), "v"(scale)); + + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(res.template AsType()(Number<1>{})) + : "v"(res.template AsType()(Number<1>{})), "v"(scale)); + + return res.template AsType()[Number<0>{}]; +} + +__device__ inline bhalf4_t i4_to_bhalf4(int q) +{ + uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); + + static constexpr uint32_t fp32_base = 0x4B000000; + + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388616.f; + fp32_intermediates[1] -= 8388616.f; + fp32_intermediates[2] -= 8388616.f; + fp32_intermediates[3] -= 8388616.f; + + vector_type res; + res.template AsType()(Number<0>{}) = bit_cast( + __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632)); + res.template AsType()(Number<1>{}) = bit_cast( + __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632)); + + return res.template AsType()[Number<0>{}]; +} + namespace tensor_operation { namespace element_wise { -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wnon-virtual-dtor" -struct UnaryOpBase +struct PassThroughPack8 { - public: - __host__ __device__ ~UnaryOpBase() = default; + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = i4_to_half4(bit_cast(x)); + result.template AsType()(Number<1>{}) = i4_to_half4(bit_cast(x) >> 8); - __host__ __device__ constexpr UnaryOpBase() = default; - __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default; - __host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default; - __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default; - __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default; + y = result.template AsType()[Number<0>{}]; +#else + vector_type dst; + vector_type src{x}; + + dst.template AsType()(Number<0>{}) = + type_convert(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + type_convert(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + type_convert(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + type_convert(src.template AsType()[Number<3>{}]); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + + __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; - __host__ __device__ virtual inline void 
operator()(float& y, const float& x) const = 0; + result.template AsType()(Number<0>{}) = i4_to_bhalf4(bit_cast(x)); + result.template AsType()(Number<1>{}) = i4_to_bhalf4(bit_cast(x) >> 16); - __host__ __device__ virtual inline void operator()(double& y, const double& x) const = 0; + y = result.template AsType()[Number<0>{}]; +#else + vector_type dst; + vector_type src{x}; - __host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0; + dst.template AsType()(Number<0>{}) = + type_convert(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + type_convert(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + type_convert(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + type_convert(src.template AsType()[Number<3>{}]); - __host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const = 0; + y = dst.template AsType()[Number<0>{}]; +#endif + } + constexpr const static bool is_pack8_invocable = true; +}; - __host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const = 0; +struct DequantPack8 +{ + template + __host__ __device__ void operator()(Y& y, const X& x, const Z& z) const; + + __host__ __device__ constexpr void + operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = i4_to_half4_scale(bit_cast(x), z); + result.template AsType()(Number<1>{}) = + i4_to_half4_scale(bit_cast(x) >> 8, z); + + y = result.template AsType()[Number<0>{}]; +#else + vector_type dst; + vector_type src{x}; + + dst.template AsType()(Number<0>{}) = + type_convert(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + type_convert(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + type_convert(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + type_convert(src.template AsType()[Number<3>{}]); + + y = dst.template AsType()[Number<0>{}]; +#endif + } - __host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0; + constexpr const static bool is_pack8_invocable = true; }; struct PassThroughPack2 @@ -44,38 +211,49 @@ struct PassThroughPack2 template __host__ __device__ void operator()(Y& y, const X& x) const; - __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const + __host__ __device__ constexpr void operator()(half2_t& y, const f8x2_t& x) const { auto t = type_convert(x); y = type_convert(t); } - constexpr const static bool is_pack2_invocable = true; -}; - -struct PassThrough final : public UnaryOpBase -{ - __host__ __device__ constexpr PassThrough() = default; - __host__ __device__ constexpr PassThrough(const PassThrough&) = default; - __host__ __device__ constexpr PassThrough(PassThrough&&) = default; - __host__ __device__ PassThrough& operator=(const PassThrough&) = default; - __host__ __device__ PassThrough& operator=(PassThrough&&) = default; - __host__ __device__ ~PassThrough() = default; - __host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; } - - __host__ __device__ inline void operator()(double& y, const double& x) const final { y = x; } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x; } + __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + 
uint8_t x_u8 = ck::bit_cast(x); + uint8_t x_l = (x_u8 & 0x0f) >> 0; + uint8_t x_h = (x_u8 & 0xf0) >> 4; - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x; } + auto l_f16 = ck::type_convert(x_l); + auto h_f16 = ck::type_convert(x_h); - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x; } + y = {l_f16, h_f16}; +#else + uint32_t t = ck::bit_cast(x); + y = ck::bit_cast(t); +#endif + } - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = x; } + constexpr const static bool is_pack2_invocable = true; +}; +struct PassThrough +{ template __host__ __device__ void operator()(Y& y, const X& x) const; + template <> + __host__ __device__ void operator()(pk_i4_t& y, const pk_i4_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(double& y, const double& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(float& y, const double& x) const { @@ -88,6 +266,18 @@ struct PassThrough final : public UnaryOpBase y = type_convert(x); } + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(half_t& y, const half_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(half_t& y, const float& x) const { @@ -136,6 +326,12 @@ struct PassThrough final : public UnaryOpBase y = type_convert(x); } + template <> + __host__ __device__ void operator()(int8_t& y, const int8_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(half_t& y, const int8_t& x) const { @@ -248,7 +444,7 @@ struct PassThrough final : public UnaryOpBase template <> __host__ __device__ void operator()(bf8_t& y, const half_t& x) const { - y = ck::type_convert(x); + y = type_convert(x); } }; @@ -321,21 +517,21 @@ struct Scale template __host__ __device__ void operator()(Y& y, const X& x) const { - y = ck::type_convert(ck::type_convert(x) * scale_); + y = type_convert(type_convert(x) * scale_); } template <> __host__ __device__ void operator()(half_t& y, const half_t& x) const { - y = ck::type_convert(scale_) * x; + y = type_convert(scale_) * x; }; template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { - const float x_tmp = ck::type_convert(x); + const float x_tmp = type_convert(x); const float y_tmp = scale_ * x_tmp; - y = ck::type_convert(y_tmp); + y = type_convert(y_tmp); }; template <> @@ -353,7 +549,7 @@ struct Scale template <> __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { - y = ck::type_convert(scale_ * ck::type_convert(x)); + y = type_convert(scale_ * type_convert(x)); }; float scale_; @@ -369,7 +565,7 @@ struct ScaleAndResetNaNToMinusInfinity template <> __host__ __device__ void operator()(float& y, const float& x) const { - y = ck::math::isnan(x) ? -ck::NumericLimits::Infinity() : scale_ * x; + y = math::isnan(x) ? 
-NumericLimits::Infinity() : scale_ * x; }; float scale_; @@ -435,45 +631,21 @@ struct UnarySquare }; }; -struct UnaryAbs final : public UnaryOpBase +struct UnaryAbs { - __host__ __device__ constexpr UnaryAbs() = default; - __host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default; - __host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default; - __host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default; - __host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default; - __host__ __device__ ~UnaryAbs() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - y = ck::math::abs(x); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - y = ck::math::abs(x); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - y = ck::math::abs(x); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - y = ck::math::abs(x); - } - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - y = ck::math::abs(x); - } + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - y = ck::math::abs(x); - } + y = math::abs(x); + }; + template <> __host__ __device__ void operator()(f8_t& y, const f8_t& x) const { y = ck::type_convert(ck::math::abs(ck::type_convert(x))); @@ -488,49 +660,28 @@ struct UnarySqrt static_assert(is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::sqrt(x); + y = math::sqrt(x); }; }; -struct Relu final : public UnaryOpBase +struct Relu { - __host__ __device__ constexpr Relu() = default; - __host__ __device__ constexpr Relu(const Relu&) = default; - __host__ __device__ constexpr Relu(Relu&&) = default; - __host__ __device__ Relu& operator=(const Relu&) = default; - __host__ __device__ Relu& operator=(Relu&&) = default; - __host__ __device__ ~Relu() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - y = x > 0 ? x : 0; - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); y = x > 0 ? x : 0; } - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { - float x_f32 = ck::type_convert(x); + float x_f32 = type_convert(x); float y_f32 = x_f32 > 0 ? 
x_f32 : 0; - y = ck::type_convert(y_f32); + y = type_convert(y_f32); } }; @@ -546,7 +697,7 @@ struct FastGelu template __device__ void operator()(Y& y, const X& x) const; - +#ifndef CK_CODE_GEN_RTC template <> __host__ void operator()(float& y, const float& x) const { @@ -557,6 +708,7 @@ struct FastGelu const float emu = exp(u); y = x / (1.f + emu); } +#endif // device code, use lower precision "__ocml_exp_f32" and "rcp" template <> @@ -568,7 +720,7 @@ struct FastGelu const float u = x * (c1 * x * x + c2); const float emu = __ocml_exp_f32(u); - y = x * ck::math::rcp(1.f + emu); + y = x * math::rcp(1.f + emu); } template <> @@ -666,59 +818,24 @@ struct Gelu } template <> - __host__ __device__ void operator()(ck::half_t& y, - const ck::half_t& x) const + __host__ __device__ void operator()(half_t& y, const half_t& x) const { - y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x)))); + y = half_t(0.5) * x * (half_t(1) + half_t(erf(float(0.70710678118f * x)))); } }; -struct Sigmoid final : public UnaryOpBase +struct Sigmoid { - __host__ __device__ constexpr Sigmoid() = default; - __host__ __device__ constexpr Sigmoid(const Sigmoid&) = default; - __host__ __device__ constexpr Sigmoid(Sigmoid&&) = default; - __host__ __device__ Sigmoid& operator=(const Sigmoid&) = default; - __host__ __device__ Sigmoid& operator=(Sigmoid&&) = default; - __host__ __device__ ~Sigmoid() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - constexpr float one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - constexpr double one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - constexpr int32_t one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - constexpr int8_t one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - constexpr half_t one = type_convert(1); - y = one / (one + ck::math::exp(-x)); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - constexpr float one = type_convert(1); - float x_f32 = ck::type_convert(x); - float y_f32 = one / (one + ck::math::exp(x_f32)); - y = ck::type_convert(y_f32); - } + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + constexpr T one = type_convert(1); + y = one / (one + math::exp(-x)); + }; }; struct Silu @@ -726,52 +843,26 @@ struct Silu template __host__ __device__ void operator()(T& y, const T& x) const { - static_assert(is_same_v || is_same_v || is_same_v || + static_assert(is_same_v || is_same_v || is_same_v || is_same_v || is_same_v, "Data type is not supported by this operation!"); constexpr T one = type_convert(1); - y = x * (one / (one + ck::math::exp(-x))); + y = x * (one / (one + math::exp(-x))); }; }; -struct TanH final : public UnaryOpBase +struct TanH { - __host__ __device__ constexpr TanH() = default; - __host__ __device__ constexpr TanH(const TanH&) = default; - __host__ __device__ constexpr TanH(TanH&&) = default; - __host__ __device__ TanH& 
operator=(const TanH&) = default; - __host__ __device__ TanH& operator=(TanH&&) = default; - __host__ __device__ ~TanH() = default; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - y = ck::math::tanh(x); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - y = ck::math::tanh(x); - } + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - y = ck::math::tanh(x); - } + y = math::tanh(x); + }; }; struct ACos @@ -780,11 +871,11 @@ struct ACos __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::acos(x); + y = math::acos(x); }; }; @@ -794,11 +885,11 @@ struct Neg __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::neg(x); + y = math::neg(x); }; }; @@ -808,11 +899,11 @@ struct ATan __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::atan(x); + y = math::atan(x); }; }; @@ -822,11 +913,11 @@ struct Sin __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::sin(x); + y = math::sin(x); }; }; @@ -836,11 +927,11 @@ struct ASinH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::asinh(x); + y = math::asinh(x); }; }; @@ -850,11 +941,11 @@ struct Cos __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::cos(x); + y = cos(x); }; }; @@ -864,11 +955,11 @@ struct ACosH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::acosh(x); + y = math::acosh(x); }; }; @@ -878,11 +969,11 @@ struct Tan __host__ __device__ void 
operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::tan(x); + y = math::tan(x); }; }; @@ -892,11 +983,11 @@ struct ATanH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::atanh(x); + y = math::atanh(x); }; }; @@ -906,11 +997,11 @@ struct SinH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::sinh(x); + y = math::sinh(x); }; }; @@ -920,11 +1011,11 @@ struct Ceil __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::ceil(x); + y = math::ceil(x); }; }; @@ -934,11 +1025,11 @@ struct Exp __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::exp(x); + y = math::exp(x); }; }; @@ -948,11 +1039,11 @@ struct CosH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::cosh(x); + y = math::cosh(x); }; }; @@ -962,11 +1053,11 @@ struct Floor __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::floor(x); + y = math::floor(x); }; }; @@ -976,11 +1067,11 @@ struct Log __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::log(x); + y = math::log(x); }; }; @@ -990,11 +1081,11 @@ struct ASin __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::asin(x); + y = math::asin(x); }; }; @@ -1004,426 +1095,146 @@ struct Rcp __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, "Data type is not supported by this operation!"); - y = ck::math::rcp(x); + y = math::rcp(x); }; }; -struct Swish final : public UnaryOpBase +struct Swish { - __host__ __device__ constexpr Swish(const Swish&) = default; - __host__ __device__ constexpr Swish(Swish&&) = default; - __host__ __device__ ~Swish() 
= default; - - __host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {} - - __host__ __device__ float get_beta() const { return beta_; } - - const float beta_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } + Swish(float beta = 1.0f) : beta_(beta) {} template __host__ __device__ void operator()(Y& y, const X& x) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "Data type is not supported by this operation!"); static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "Data type is not supported by this operation!"); float bx = -beta_ * type_convert(x); - y = type_convert(x / (1.f + ck::math::exp(bx))); - } + y = type_convert(x / (1.f + math::exp(bx))); + }; + + const float beta_; }; -struct SoftRelu final : public UnaryOpBase +struct SoftRelu { - __host__ __device__ constexpr SoftRelu(const SoftRelu&) = default; - __host__ __device__ constexpr SoftRelu(SoftRelu&&) = default; - __host__ __device__ ~SoftRelu() = default; + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; - __host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } - - const float alpha_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - constexpr float one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - constexpr double one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - constexpr int32_t one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - constexpr int8_t one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - constexpr half_t one = type_convert(1); - y = 
ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - bhalf_t casted_alpha = type_convert(alpha_); - constexpr bhalf_t one = type_convert(1); - y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha; } + const float alpha_; }; -struct Power final : public UnaryOpBase +struct Power { - __host__ __device__ constexpr Power(const Power&) = default; - __host__ __device__ constexpr Power(Power&&) = default; - __host__ __device__ ~Power() = default; + Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) + : alpha_(alpha), beta_(beta), gamma_(gamma){}; - __host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) - : alpha_(alpha), beta_(beta), gamma_(gamma) + template + __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + T casted_gamma = type_convert(gamma_); + T shifted_scaled_x = casted_alpha + casted_beta * x; + y = math::pow(shifted_scaled_x, casted_gamma); } - - __host__ __device__ float get_alpha() const { return alpha_; } - - __host__ __device__ float get_beta() const { return beta_; } - - __host__ __device__ float get_gamma() const { return gamma_; } - const float alpha_; const float beta_; const float gamma_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - float casted_beta = type_convert(beta_); - float casted_gamma = type_convert(gamma_); - - float shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - double casted_beta = type_convert(beta_); - double casted_gamma = type_convert(gamma_); - - double shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - int32_t casted_beta = type_convert(beta_); - int32_t casted_gamma = type_convert(gamma_); - - int32_t shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - int8_t casted_beta = type_convert(beta_); - int8_t casted_gamma = type_convert(gamma_); - - int8_t shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - half_t casted_beta = type_convert(beta_); - half_t casted_gamma = type_convert(gamma_); - - half_t 
shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - bhalf_t casted_alpha = type_convert(alpha_); - bhalf_t casted_beta = type_convert(beta_); - bhalf_t casted_gamma = type_convert(gamma_); - - bhalf_t shifted_scaled_x = casted_alpha + casted_beta * x; - y = ck::math::pow(shifted_scaled_x, casted_gamma); - } }; -struct ClippedRelu final : public UnaryOpBase +struct ClippedRelu { - __host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default; - __host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default; - __host__ __device__ ~ClippedRelu() = default; + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; - __host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f) - : alpha_(alpha), beta_(beta) + template + __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + y = math::min(casted_beta, math::max(casted_alpha, x)); } - - __host__ __device__ float get_alpha() const { return alpha_; } - - __host__ __device__ float get_beta() const { return beta_; } - const float alpha_; const float beta_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - float casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - double casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - int32_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - int8_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - half_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final - { - bhalf_t casted_alpha = type_convert(alpha_); - bhalf_t casted_beta = type_convert(beta_); - y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); - } }; -struct LeakyRelu final : public UnaryOpBase +struct LeakyRelu { - __host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default; - __host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default; - __host__ __device__ ~LeakyRelu() = default; - - __host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; - const float alpha_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = 
type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - y = x >= 0 ? x : x * casted_alpha; - } - - __host__ __device__ inline void operator()([[maybe_unused]] bhalf_t& y, - [[maybe_unused]] const bhalf_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; } + const float alpha_; }; -struct Elu final : public UnaryOpBase +struct Elu { - __host__ __device__ constexpr Elu(const Elu&) = default; - __host__ __device__ constexpr Elu(Elu&&) = default; - __host__ __device__ ~Elu() = default; - - __host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } - - const float alpha_; + Elu(float alpha = 1.f) : alpha_(alpha){}; - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); - } - - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - bhalf_t casted_alpha = type_convert(alpha_); - y = x > 0 ? x : casted_alpha * ck::math::expm1(x); + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x > 0 ? 
x : casted_alpha * math::expm1(x); } + const float alpha_; }; -struct Logistic final : public UnaryOpBase +struct Logistic { - __host__ __device__ constexpr Logistic(const Logistic&) = default; - __host__ __device__ constexpr Logistic(Logistic&&) = default; - __host__ __device__ ~Logistic() = default; - - __host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {} - - __host__ __device__ float get_alpha() const { return alpha_; } - - const float alpha_; - - __host__ __device__ inline void operator()(float& y, const float& x) const final - { - float casted_alpha = type_convert(alpha_); - constexpr float one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(double& y, const double& x) const final - { - double casted_alpha = type_convert(alpha_); - constexpr double one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final - { - int32_t casted_alpha = type_convert(alpha_); - constexpr int32_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final - { - int8_t casted_alpha = type_convert(alpha_); - constexpr int8_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } - - __host__ __device__ inline void operator()(half_t& y, const half_t& x) const final - { - half_t casted_alpha = type_convert(alpha_); - constexpr half_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); - } + Logistic(float alpha = 1.f) : alpha_(alpha){}; - __host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final + template + __host__ __device__ void operator()(T& y, const T& x) const { - bhalf_t casted_alpha = type_convert(alpha_); - constexpr bhalf_t one = type_convert(1); - y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); } + const float alpha_; }; struct ConvInvscale @@ -1488,7 +1299,7 @@ struct ConvScaleRelu __host__ __device__ void operator()(f8_t& e, const float& c) const { float x; - Relu{}(x, c * scale_in_ * scale_wei_); + Relu{}.template operator()(x, c * scale_in_ * scale_wei_); e = type_convert(x * scale_out_); }; @@ -1505,10 +1316,10 @@ struct FastNumericArrayConverter }; template <> -struct FastNumericArrayConverter +struct FastNumericArrayConverter { using InputArray = vector_type; - using OutputArray = vector_type; + using OutputArray = vector_type; __device__ static OutputArray convert(InputArray const& Input) { @@ -1538,13 +1349,13 @@ struct FastNumericArrayConverter }; template -struct FastNumericArrayConverter +struct FastNumericArrayConverter { static constexpr int VEC_WIDTH = 4; static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); using InputArray = vector_type; - using OutputArray = vector_type; + using OutputArray = vector_type; __device__ static OutputArray convert(InputArray const& Input) { @@ -1553,7 +1364,7 @@ struct FastNumericArrayConverter OutputArray Output; using Vec_InputArray = vector_type; - using Vec_OutputArray = vector_type; + using Vec_OutputArray 
= vector_type; Vec_OutputArray* half_4_ptr = reinterpret_cast(&Output); Vec_InputArray const* uint8_4_ptr = reinterpret_cast(&Input); @@ -1569,225 +1380,138 @@ struct FastNumericArrayConverter struct DynamicUnaryOp { - - DynamicUnaryOp& operator=(const DynamicUnaryOp& other) - { - if(this != &other) - { - unary_op_ptr_ = other.unary_op_ptr_; - unary_op_type_ = other.unary_op_type_; - } - return *this; - } - __host__ __device__ DynamicUnaryOp() = delete; __host__ __device__ DynamicUnaryOp(const Swish& swish) + : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_} { - unary_op_type_ = UnaryOpType::Swish; - beta = swish.get_beta(); } __host__ __device__ DynamicUnaryOp(const Swish&& swish) + : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_} { - unary_op_type_ = UnaryOpType::Swish; - beta = swish.get_beta(); } - __host__ __device__ DynamicUnaryOp(const Sigmoid&) { unary_op_type_ = UnaryOpType::Sigmoid; } + __host__ __device__ DynamicUnaryOp(const Sigmoid&) : unary_op_type_(UnaryOpType::Sigmoid) {} - __host__ __device__ DynamicUnaryOp(const Sigmoid&&) { unary_op_type_ = UnaryOpType::Sigmoid; } + __host__ __device__ DynamicUnaryOp(const Sigmoid&&) : unary_op_type_(UnaryOpType::Sigmoid) {} __host__ __device__ DynamicUnaryOp(const PassThrough&) + : unary_op_type_(UnaryOpType::PassThrough) { - unary_op_type_ = UnaryOpType::PassThrough; } __host__ __device__ DynamicUnaryOp(const PassThrough&&) + : unary_op_type_(UnaryOpType::PassThrough) { - unary_op_type_ = UnaryOpType::PassThrough; } __host__ __device__ DynamicUnaryOp(const Logistic& logistic) + : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_} { - unary_op_type_ = UnaryOpType::Logistic; - alpha = logistic.get_alpha(); } __host__ __device__ DynamicUnaryOp(const Logistic&& logistic) + : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_} { - unary_op_type_ = UnaryOpType::Logistic; - alpha = logistic.get_alpha(); } - __host__ __device__ DynamicUnaryOp(const TanH&) { unary_op_type_ = UnaryOpType::TanH; } + __host__ __device__ DynamicUnaryOp(const TanH&) : unary_op_type_(UnaryOpType::TanH) {} - __host__ __device__ DynamicUnaryOp(const TanH&&) { unary_op_type_ = UnaryOpType::TanH; } + __host__ __device__ DynamicUnaryOp(const TanH&&) : unary_op_type_(UnaryOpType::TanH) {} - __host__ __device__ DynamicUnaryOp(const Relu&) { unary_op_type_ = UnaryOpType::Relu; } + __host__ __device__ DynamicUnaryOp(const Relu&) : unary_op_type_(UnaryOpType::Relu) {} - __host__ __device__ DynamicUnaryOp(const Relu&&) { unary_op_type_ = UnaryOpType::Relu; } + __host__ __device__ DynamicUnaryOp(const Relu&&) : unary_op_type_(UnaryOpType::Relu) {} __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu) + : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_} { - unary_op_type_ = UnaryOpType::SoftRelu; - alpha = softrelu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu) + : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_} { - unary_op_type_ = UnaryOpType::SoftRelu; - alpha = softrelu.get_alpha(); } - __host__ __device__ DynamicUnaryOp(const UnaryAbs&) { unary_op_type_ = UnaryOpType::UnaryAbs; } + __host__ __device__ DynamicUnaryOp(const UnaryAbs&) : unary_op_type_(UnaryOpType::UnaryAbs) {} - __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) { unary_op_type_ = UnaryOpType::UnaryAbs; } + __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) : unary_op_type_(UnaryOpType::UnaryAbs) {} __host__ __device__ DynamicUnaryOp(const Power& pow) + : 
unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_) { - unary_op_type_ = UnaryOpType::Power; - alpha = pow.get_alpha(); - beta = pow.get_beta(); - gamma = pow.get_gamma(); } __host__ __device__ DynamicUnaryOp(const Power&& pow) + : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_) { - unary_op_type_ = UnaryOpType::Power; - alpha = pow.get_alpha(); - beta = pow.get_beta(); - gamma = pow.get_gamma(); } __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu) + : unary_op_type_(UnaryOpType::ClippedRelu), + clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_} { - unary_op_type_ = UnaryOpType::ClippedRelu; - alpha = clippedrelu.get_alpha(); - beta = clippedrelu.get_beta(); } __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu) + : unary_op_type_(UnaryOpType::ClippedRelu), + clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_} { - unary_op_type_ = UnaryOpType::ClippedRelu; - alpha = clippedrelu.get_alpha(); - beta = clippedrelu.get_beta(); } __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu) + : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_} { - unary_op_type_ = UnaryOpType::LeakyRelu; - alpha = leakyrelu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu) + : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_} { - unary_op_type_ = UnaryOpType::LeakyRelu; - alpha = leakyrelu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const Elu& elu) + : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_} { - unary_op_type_ = UnaryOpType::Elu; - alpha = elu.get_alpha(); } __host__ __device__ DynamicUnaryOp(const Elu&& elu) - { - unary_op_type_ = UnaryOpType::Elu; - alpha = elu.get_alpha(); - } - - __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) - : unary_op_type_(dynamic_op.unary_op_type_), - unary_op_ptr_(dynamic_op.unary_op_ptr_), - alpha(dynamic_op.alpha), - beta(dynamic_op.beta), - gamma(dynamic_op.gamma) + : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_} { } - __host__ __device__ ~DynamicUnaryOp() - { - switch(unary_op_type_) - { - case(UnaryOpType::Swish): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Sigmoid): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::PassThrough): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Logistic): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::TanH): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Relu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::SoftRelu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::UnaryAbs): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Power): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::ClippedRelu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::LeakyRelu): delete static_cast(unary_op_ptr_); break; - case(UnaryOpType::Elu): delete static_cast(unary_op_ptr_); break; - - default: break; - } - } - - __device__ void InitUnaryOpPtrOnDevice() - { - switch(unary_op_type_) - { - case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break; - case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break; - case(UnaryOpType::PassThrough): unary_op_ptr_ = new PassThrough; break; - case(UnaryOpType::Logistic): unary_op_ptr_ = new Logistic(alpha); break; - case(UnaryOpType::TanH): unary_op_ptr_ = new TanH; break; - case(UnaryOpType::Relu): unary_op_ptr_ = new Relu; break; - case(UnaryOpType::SoftRelu): unary_op_ptr_ = 
new SoftRelu(alpha); break; - case(UnaryOpType::UnaryAbs): unary_op_ptr_ = new UnaryAbs; break; - case(UnaryOpType::Power): unary_op_ptr_ = new Power(alpha, beta, gamma); break; - case(UnaryOpType::ClippedRelu): unary_op_ptr_ = new ClippedRelu(alpha, beta); break; - case(UnaryOpType::LeakyRelu): unary_op_ptr_ = new LeakyRelu(alpha); break; - case(UnaryOpType::Elu): unary_op_ptr_ = new Elu(alpha); break; - - default: unary_op_ptr_ = nullptr; break; - } - } + __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) = default; - template - __device__ void operator()(Y& y, const X& x) const - { - isSupported(); - unary_op_ptr_->operator()(y, x); - } + __host__ __device__ ~DynamicUnaryOp() {} template - __host__ void operator()(Y& y, const X& x) const + __host__ __device__ void operator()(Y& y, const X& x) const { - isSupported(); switch(unary_op_type_) { - case(UnaryOpType::Swish): Swish{}.operator()(y, x); break; - case(UnaryOpType::Sigmoid): Sigmoid{}.operator()(y, x); break; - case(UnaryOpType::PassThrough): PassThrough{}.operator()(y, x); break; - case(UnaryOpType::Logistic): Logistic{}.operator()(y, x); break; - case(UnaryOpType::TanH): TanH{}.operator()(y, x); break; - case(UnaryOpType::Relu): Relu{}.operator()(y, x); break; - case(UnaryOpType::SoftRelu): SoftRelu{}.operator()(y, x); break; - case(UnaryOpType::UnaryAbs): UnaryAbs{}.operator()(y, x); break; - case(UnaryOpType::Power): Power{}.operator()(y, x); break; - case(UnaryOpType::ClippedRelu): ClippedRelu{}.operator()(y, x); break; - case(UnaryOpType::LeakyRelu): LeakyRelu{}.operator()(y, x); break; - case(UnaryOpType::Elu): Elu{}.operator()(y, x); break; + case(UnaryOpType::Swish): swish_(y, x); break; + case(UnaryOpType::Sigmoid): sigmoid_(y, x); break; + case(UnaryOpType::PassThrough): pass_through_(y, x); break; + case(UnaryOpType::Logistic): logistic_(y, x); break; + case(UnaryOpType::TanH): tanh_(y, x); break; + case(UnaryOpType::Relu): relu_(y, x); break; + case(UnaryOpType::SoftRelu): soft_relu_(y, x); break; + case(UnaryOpType::UnaryAbs): unary_abs_(y, x); break; + case(UnaryOpType::Power): power_(y, x); break; + case(UnaryOpType::ClippedRelu): clipped_relu_(y, x); break; + case(UnaryOpType::LeakyRelu): leaky_relu_(y, x); break; + case(UnaryOpType::Elu): elu_(y, x); break; default: break; } } - template - __device__ __host__ constexpr void isSupported() const + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { - - static_assert(std::is_same::value, "X and Y must be of the same type"); - - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Data type is not supported by this operation!"); + float y_float; + float x_float = type_convert(x); + this->operator()(y_float, x_float); + y = type_convert(y_float); } private: @@ -1809,12 +1533,20 @@ struct DynamicUnaryOp public: UnaryOpType unary_op_type_; - UnaryOpBase* unary_op_ptr_ = nullptr; - float alpha; - float beta; - float gamma; + + Swish swish_; + Sigmoid sigmoid_; + PassThrough pass_through_; + Logistic logistic_; + TanH tanh_; + Relu relu_; + SoftRelu soft_relu_; + UnaryAbs unary_abs_; + Power power_; + ClippedRelu clipped_relu_; + LeakyRelu leaky_relu_; + Elu elu_; }; -#pragma clang diagnostic pop } // namespace element_wise } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 
56c37b1b7240de0e85fe2ddd6faa7264deaaa32e..2bc9ef87acfcab6ccd54e4100e69dceb2b2a50d8 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1,14 +1,17 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/math.hpp" #include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" +#ifndef CK_CODE_GEN_RTC #include #include +#endif namespace ck { @@ -978,8 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit // Create 3D grid const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock); - - return std::make_tuple(N0, M0, k_split); + return make_tuple(N0, M0, k_split); } template @@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK uint32_t dp_for_sk_iters = k_iters_per_tile.get(); uint32_t best_sk_score = - std::numeric_limits::max(); // we need to find the smallest sk iters + NumericLimits::Max(); // we need to find the smallest sk iters for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles; tentative_sk_blocks++) { diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index 206ea00b9d0a94dcd5f12c02948d8f84f9223fb5..f4d0989088c6c59c2fff2dd238e1bdc5a38eeb3c 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -515,9 +515,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index 9469fa7bc7b27351f393caed43a9bc12a2b8780d..55e254e015d1303461d859cff30a774639535608 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -448,8 +448,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle // acc1[m][o] += acc[m][n] * B1[n][o] // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? 
true + : false; constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + lcm_AK1_BK1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, @@ -607,6 +615,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index 42f7c2a33fbd77ec58ed16889860924b1f8f8dbb..fd16927cc1e3f92268299debffb5935cf973a70c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -361,10 +361,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle const auto M = d0_grid_desc_m_n.GetLength(I0); const auto N = d0_grid_desc_m_n.GetLength(I1); - constexpr auto mfma = - MfmaSelector::selected_mfma; - constexpr auto N3 = mfma.num_groups_per_blk; - constexpr auto N5 = mfma.group_size; + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + math::lcm(A0K1, B0K1) <= 4) + ? true + : false; + constexpr auto mfma = MfmaSelector::selected_mfma; + constexpr auto N3 = mfma.num_groups_per_blk; + constexpr auto N5 = mfma.group_size; return transform_tensor_descriptor( d0_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple( @@ -643,9 +651,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle // acc1[m][o] += acc[m][n] * B1[n][o] // sanity check - constexpr index_t KPack = math::max( - math::lcm(A0K1, B0K1), - MfmaSelector::selected_mfma.k_per_blk); + constexpr auto lcm_A0K1_B0K1 = math::lcm(A0K1, B0K1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_A0K1_B0K1 <= 4) + ? true + : false; + constexpr index_t KPack = + math::max(lcm_A0K1_B0K1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm0 = BlockwiseGemmXdlops_v2< BlockSize, @@ -856,11 +874,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); - constexpr index_t Gemm1KPack = math::max( - math::lcm( - MfmaSelector::selected_mfma.group_size, - B1K1), - MfmaSelector::selected_mfma.k_per_blk); + // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size + // selected_mfma.k_per_blk <= Gemm1KPack + // + // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common + // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case + // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs + // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will + // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. 
+ // therefore we may just as well assign Gemm1KPack = group_size + + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size; auto blockwise_gemm1 = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index bc76d4cc4fb9e905cdb1547273866b66575ff334..1f7458e68f8505f5ae20163e9ea69289e33f73b7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -343,10 +343,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle const auto M = d0_grid_desc_m_n.GetLength(I0); const auto N = d0_grid_desc_m_n.GetLength(I1); - constexpr auto mfma = MfmaSelector::selected_mfma; - constexpr auto N3 = mfma.num_groups_per_blk; - constexpr auto N4 = mfma.num_input_blks; - constexpr auto N5 = mfma.group_size; + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + math::lcm(AK1, BK1) <= 4) + ? true + : false; + constexpr auto mfma = + MfmaSelector::selected_mfma; + constexpr auto N3 = mfma.num_groups_per_blk; + constexpr auto N4 = mfma.num_input_blks; + constexpr auto N5 = mfma.group_size; return transform_tensor_descriptor( d0_grid_desc_m_n, make_tuple(make_unmerge_transform( @@ -552,8 +558,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // acc1[m][o] += acc[m][n] * B1[n][o] // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + lcm_AK1_BK1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, @@ -773,6 +787,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index afb2ad2e760396c200930254f871ff13e032a91a..f7746b470f1db4030adff0f700a494d23d3285ec 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -469,8 +469,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // acc1[m][o] += acc[m][n] * B1[n][o] // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + lcm_AK1_BK1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, @@ -628,6 +636,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. 
Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index 35176591c1cfdd4ed0037b5baf24965660e4a0ee..8b3f51b9b0d1140ddf8365d89cc07b4360217c86 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -498,8 +498,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + lcm_AK1_BK1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 60c02d64e11cc8be94de6bd6aa8041f2d30b55a4..344656b13f6b5b77ab9ae12e8d02c763ff1d0ca8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once @@ -101,7 +101,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; using BComputeDataType = @@ -423,10 +423,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle } template - __host__ __device__ static auto - MakeAsGridDescriptor_M_K(const std::array& MRaws, - const std::array& KRaws, - const std::array& AsStride) + __host__ __device__ static auto MakeAsGridDescriptor_M_K( +#ifdef CK_CODE_GEN_RTC + const ck::Array& MRaws, + const ck::Array& KRaws, + const ck::Array& AsStride +#else + const std::array& MRaws, + const std::array& KRaws, + const std::array& AsStride +#endif + ) { return generate_tuple( [&](auto i) { @@ -462,10 +469,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle } template - __host__ __device__ static auto - MakeBsGridDescriptor_N_K(const std::array& NRaws, - const std::array& KRaws, - const std::array& BsStride) + __host__ __device__ static auto MakeBsGridDescriptor_N_K( +#ifdef CK_CODE_GEN_RTC + const ck::Array& NRaws, + const ck::Array& KRaws, + const ck::Array& BsStride +#else + const std::array& NRaws, + const std::array& KRaws, + const std::array& BsStride +#endif + ) { return generate_tuple( [&](auto i) { @@ -500,10 +514,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle } template - __host__ __device__ static auto - MakeDsGridDescriptor_M_N(const std::array& MRaws, - const std::array& NRaws, - const std::array& DsStride) + __host__ __device__ static auto MakeDsGridDescriptor_M_N( +#ifdef CK_CODE_GEN_RTC + const ck::Array& MRaws, + const ck::Array& NRaws, + const ck::Array& DsStride +#else + const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride +#endif + ) { return generate_tuple( [&](auto i) { @@ -969,9 +990,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle const index_t M, const index_t N, const index_t K, +#ifdef CK_CODE_GEN_RTC + const ck::Array StrideAs, + const ck::Array StrideBs, + const ck::Array StrideDs, +#else const std::array StrideAs, const std::array StrideBs, const std::array StrideDs, +#endif const index_t StrideE, const Block2ETileMap& block_2_etile_map) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 5c9f40b51a5ba28810352cf9dcf6d6e2dca1a404..60ee78528df229d35450db3d25e10420b19db0f8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -464,8 +464,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? 
true + : false; constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + lcm_AK1_BK1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index e6085fad8c8b13c541beb373114a58c22f6c1a60..eb1eb533d7aa38553ce945bc3452249f9e6d0a50 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -100,7 +100,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; using BComputeDataType = @@ -473,11 +473,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); } +#ifdef CK_CODE_GEN_RTC + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const ck::Array& MRaws, + const ck::Array& NRaws, + const ck::Array& DsStride) +#else template __host__ __device__ static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, const std::array& NRaws, const std::array& DsStride) +#endif { return generate_tuple( [&](auto i) { @@ -941,7 +949,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const index_t K, const index_t StrideA, const index_t StrideB, +#ifdef CK_CODE_GEN_RTC + const ck::Array StrideDs, +#else const std::array StrideDs, +#endif const index_t StrideE, const Block2ETileMap& block_2_etile_map) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index cd36b9e51ae6db6c74f491b253ae7099a259d747..b4c5d004c49808273fa0b2f970eb94b18222165f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -164,7 +164,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; #else diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index ae93af192e6768092b707b9bcc75138453b7c873..d1d97da5b0dc57a8fc9e19b57a818480a14676bf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -599,9 +599,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - 
MfmaSelector::selected_mfma.k_per_blk); + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; + constexpr index_t KPack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 44cbbcd04967064ce36efb7e826d0f2714d69295..9dad66913aec8445c1d99d6d5254517bbd040cdd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once - +#ifndef CK_CODE_GEN_RTC #include #include +#endif #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" @@ -53,12 +54,15 @@ constexpr auto GridwiseGemmPipeline_Selector() } else { +#ifndef CK_CODE_GEN_RTC std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; +#endif } } } // namespace ck +#ifndef CK_CODE_GEN_RTC inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) { switch(p) @@ -71,3 +75,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) } return os; } +#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 0e5777e561476da30681b67f6d0d85a152a683f5..7105fa70124060b6a4363c97f5c3704a7059647e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -451,8 +451,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + lcm_AK1_BK1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 0078660556418671d685c963ac3d25a8b82d087c..3429c20e73bdfabe0be21e62d211d958ed64122b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -581,9 +581,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? 
true + : false; constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -1006,9 +1013,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp index caf8f040f4a78232386a1f0c3d0be02b4dc63295..d7c87a170c9f7ef5d5467be0c10de131d85306cc 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -595,9 +595,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; + constexpr index_t KPack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp index d2a06ba9afe5b89be9b19c3bb1183d7b21c219fb..08d9386d72a4d57ed536e2d17889be47486eb2da 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp @@ -79,9 +79,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? 
true + : false; static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp old mode 100755 new mode 100644 index 6ef35da485bcb7d28fb685da37e1aa45b13066e3..e04f24c9890666d6b8e0ec0009dd7bd88ec4d698 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -139,9 +139,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; __host__ static auto CalculateMPadded(index_t M) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index db9625c6e6032114b87209bfc47048f4c6d01b86..af91721c8ae3690605da5e1b94720bcc5ad361d1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -869,9 +869,16 @@ struct GridwiseGemm_xdl_cshuffle_v2 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; + constexpr index_t KPack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< // BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index da9ee4d77f2735e7629bf5b412c2973bcf92dad8..7a61000eb12fdbe9e856b04005ef64d4aeeba604 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -127,7 +127,9 @@ template + typename ComputeTypeB = ComputeTypeA, + bool PermuteA = false, + bool PermuteB = false> struct GridwiseGemm_xdl_cshuffle_v3 { static constexpr auto I0 = Number<0>{}; @@ -145,12 +147,33 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? 
true + : false; static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); @@ -319,6 +342,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 using GemmSpecialization = tensor_operation::device::GemmSpecialization; + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + if constexpr(GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding) { @@ -373,15 +400,39 @@ struct GridwiseGemm_xdl_cshuffle_v3 } else { - // not pad N or K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } } } @@ -572,7 +623,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead; + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; } else if constexpr(is_same_v) { @@ -585,7 +636,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 } else if constexpr(is_same_v) { - b_k_split_offset = blockIdx.z * karg.KRead; + if constexpr(!PermuteB) + { + b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } } if(blockIdx.z < static_cast(karg.KBatch - 1)) @@ -627,9 +686,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 // in some cases. else if constexpr(is_same::value) { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 - ? 
1 - : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( make_tuple( AK0Number * Number{}, Number{}, AK1Number), @@ -765,10 +823,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 else if constexpr(is_same::value) { // NLdsLayer * K0 as logical Bank - constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(BDataType); - ; + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize; constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( make_tuple( BK0Number * Number{}, Number{}, BK1Number), @@ -950,8 +1006,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)), + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), c_block_size * sizeof(CShuffleDataType)); } @@ -1348,8 +1404,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + - a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); @@ -1820,16 +1877,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf_ping = make_dynamic_buffer( - static_cast(p_shared_0) + - a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType)), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); auto a_block_buf_pong = make_dynamic_buffer( static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf_pong = make_dynamic_buffer( - static_cast(p_shared_1) + - a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType)), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2e62110416eed0eb82feb4b892ea7ba0b8b5f6ee --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -0,0 +1,2217 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
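+// Editorial note (sketch, not part of the original patch): this new header appears to
+// add a B-scale variant of GridwiseGemm_xdl_cshuffle_v3 for block-wise dequantized GEMM.
+// The B operand carries one BScaleType (ck::half_t) scale per ScaleBlockN x ScaleBlockK
+// tile, so the effective weight is roughly
+//   B(n, k) ~= b_scale[n / ScaleBlockN][k / ScaleBlockK] * B_quant(n, k),
+// and the scale pointer is threaded through the kernels and SplitKBatchOffset below.
+// As an illustrative worked example of the split-K offsets (assuming row-major A, a
+// column-major non-permuted B, KRead = 256, APackedSize = 1, BPackedSize = 2 for
+// pk_i4_t, and ScaleBlockK = 128; all of these values are assumptions), blocks
+// z = 0..3 would start reading at
+//   a_k_split_offset     = z * 256 / 1    -> 0, 256, 512, 768 elements of A,
+//   b_k_split_offset     = z * 256 / 2    -> 0, 128, 256, 384 packed elements of B,
+//   scale_k_split_offset = z * (256/128)  -> 0, 2, 4, 6 scale entries.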
+ +#pragma once + +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/utility/common_header.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + using BScaleType = ck::half_t; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; 
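+    // Editorial sketch (assumption, not from the original patch): AK0Number/BK0Number
+    // are the per-block K0 factors of the K = K0 * K1 split, i.e. KPerBlock / AK1Value
+    // and KPerBlock / BK1Value, where AK1Value/BK1Value are the contiguous K-vector
+    // lengths used for global loads. The KPack defined a few lines below is
+    //   KPack = max(lcm(AK1, BK1), selected_mfma.k_per_blk),
+    // and is_single_rate_mfma presumably forces the single-rate f8/bf8 MFMA selection
+    // when lcm(AK1, BK1) <= 4, since the double-rate instructions appear to need eight
+    // K-elements per operand slice. For example, AK1 = BK1 = 16 keeps the double-rate
+    // path, while AK1 = BK1 = 4 gives lcm = 4 and falls back to single-rate selection.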
+ static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + 
make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + 
b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Weight Tile Permute + constexpr index_t BK01 = KPerBlock / BK1Value; + // const index_t BK00 = BK0 / BK01; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return 
transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + StrideScaleB{StrideScaleB_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "SScaleB:" << StrideScaleB << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t StrideScaleB; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && 
(!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + + const BScaleType* p_b_scale_grid; + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = k_id * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = k_id * k0_offset / BPackedSize; + } + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB; + } + else if constexpr(is_same_v) + { + scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK); + } + + if(k_id < (karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = k_id * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t scale_k_split_offset; // New member for scale matrix offset + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 
1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * 
sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping 
+ // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + 
BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // b scale + // static_assert(KPerBlock <= ScaleBlockK); + static constexpr auto mfma = + MfmaSelector{}; + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + + static constexpr auto ScaleSliceSizeN = NXdlPerWave; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + static constexpr auto KBlockScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + auto b_thread_offset_n = + get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl; + auto b_thread_offset_k = (get_thread_local_1d_id() % 64) / NPerXdl * KPerThread; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, + b_thread_offset_k / ScaleBlockK)); + + constexpr auto b_scale_thread_slice_copy_step = + make_tuple(make_multi_index(NWaves * NPerXdl, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, KBlockScaleSliceSizeK)); + + const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock; + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + num_k_block_main_loop, + 
num_k_block_per_scale); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + 
ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + 
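// The B-scale descriptor built a few lines below holds one scale per
// (ScaleBlockN x ScaleBlockK) tile of B, laid out as a 2-D grid of shape
// (ceil(N / ScaleBlockN), ceil(K / ScaleBlockK)) with strides (StrideScaleB, 1).
// A rough sketch with hypothetical sizes (N = K = 4096, ScaleBlockN = ScaleBlockK = 128):
//
//   const ck::index_t scale_n = ck::math::integer_divide_ceil(4096, 128); // 32 rows
//   const ck::index_t scale_k = ck::math::integer_divide_ceil(4096, 128); // 32 columns
//   // For a densely packed scale tensor, StrideScaleB would typically equal scale_k,
//   // so scale(n_blk, k_blk) sits at offset n_blk * StrideScaleB + k_blk.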
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // B Scale grid + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(problem.StrideScaleB, 1)); + + Run(p_a_grid, + p_b_grid, + p_c_grid, + p_b_scale_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + 
ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // B scale + static constexpr auto mfma = + MfmaSelector{}; + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + + const index_t ScaleSliceSizeN = NXdlPerWave; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + static constexpr auto KBlockScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + auto b_thread_offset_n = + get_thread_local_1d_id() % NPerXdl + 
(get_thread_local_1d_id() / 64) % NWaves * NPerXdl; + auto b_thread_offset_k = (get_thread_local_1d_id() % 64) / NPerXdl * KPerThread; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, + b_thread_offset_k / ScaleBlockK)); + + constexpr auto b_scale_thread_slice_copy_step = + make_tuple(make_multi_index(NWaves * NPerXdl, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, KBlockScaleSliceSizeK)); + + const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock; + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + + num_k_block_main_loop, + num_k_block_per_scale); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix 
starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + 
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(problem.StrideScaleB, 1)); + + Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_b_scale_grid, + p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 7f815de1f9d3ffc145664e8781d6f7af20cbcde1..0a62464cc2d1dd5335585a0188ba331934df74e5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -489,8 +489,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? 
true + : false; constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + lcm_AK1_BK1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index 8675a9242acf0d573cc0316f3f89b95b654b38af..6a4b1cc14b83328d46e03f85351c8bcba99e6609 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -487,9 +487,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle else if(TileMathThreadGroup::IsBelong()) { // branch early for math wave - constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; + constexpr index_t KPack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< TileMathThreadGroupSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 5617f67f8be20f9d775f157c6e24804c4a906241..b41e747a3aa562cc7c1bd4295d47b63c0593db4c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -271,7 +271,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // when mfma if fixed, remove this section and update // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB, // throughout this file -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using FloatAAdjusted = conditional_t, ck::bhalf_t, ComputeTypeA>; using FloatBAdjusted = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 4f3caff24893fc10239bce5d0d3290a0590ee5ea..5c3d9b7ba4bc0f8899b0ee6eeabbbaa89b503c16 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -254,7 +254,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update // FloatABAdjusted -> FloatAB throughout this file -#if CK_WORKAROUND_DENORM_FIX +#if CK_GFX90A_DENORM_WORKAROUND using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; #else using FloatABAdjusted = FloatAB; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 15c64f2e474272ac05f5a901e9bb8899103875d7..7db87986956efaa40344df04c6d103296aae4301 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -446,8 +446,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool 
is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; constexpr index_t k_pack = math::max( - math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + lcm_AK1_BK1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1::type = false> struct ThreadwiseTensorSliceTransfer_v2 { - static_assert((InvalidElementAsNaN && !std::is_integral::value) || + static_assert((InvalidElementAsNaN && !ck::is_integral::value) || (!InvalidElementAsNaN), "Filling invalid element as NaN is only for floating point types"); @@ -1007,6 +1007,13 @@ struct ThreadwiseTensorSliceTransfer_v4 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx) : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) { @@ -1015,6 +1022,11 @@ struct ThreadwiseTensorSliceTransfer_v4 static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong! Not divisible"); + + if constexpr(is_same_v, pk_i4_t>) + { + static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); + } } template src_tmp_vector; + vector_type_maker_t src_tmp_vector; using src_vector_t = typename decltype(src_tmp_vector)::type; @@ -1120,7 +1132,8 @@ struct ThreadwiseTensorSliceTransfer_v4 if constexpr(SrcBuffer::IsDynamicBuffer()) { src_tmp_vector.template AsType()(Number<0>{}) = - src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + src_buf.template Get(src_data_coord.GetOffset() / PackedSize, + is_src_valid); } else if constexpr(SrcBuffer::IsStaticBuffer()) { @@ -1133,9 +1146,236 @@ struct ThreadwiseTensorSliceTransfer_v4 }); } - if constexpr(is_same, f8_t>::value && - is_same, half_t>::value && - SrcScalarPerVector % 2 == 0) + if constexpr(is_same, pk_i4_t>::value) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + constexpr index_t pack_size = 8; + + static_assert(SrcScalarPerVector % pack_size == 0, ""); + + using src_v_t = typename vector_type_maker_t::type; + using dst_v_t = typename vector_type_maker_t::type; + + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack8{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else if constexpr(is_same, f8_t>::value && + is_same, half_t>::value && + SrcScalarPerVector % 2 == 0) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + constexpr index_t pack_size = 2; + + using dst_v_t = typename vector_type_maker_t::type; + using src_v_t = typename vector_type_maker_t::type; + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack2{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, 
SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + }); + } + + // Fuse scale + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstData& scale, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // scalar per access of each dim + constexpr auto src_scalar_per_access = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number{}; + } + else + { + return Number<1>{}; + } + }, + Number{}); + + // scalar step (if steping on SrcVectorDim) of each dim + constexpr auto src_scalar_step_in_vector = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number<1>{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { +#if 0 + // TODO: unable to compile + // position in slice window + constexpr auto data_to_origin_disp_idx = + container_reorder_given_old2new(ordered_access_idx, dim_access_order) * + src_scalar_per_access; +#else + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access; +#endif + // src coordinate + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, 
src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_tmp_vector + if constexpr(SrcBuffer::IsDynamicBuffer()) + { + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_data_coord.GetOffset() / PackedSize, + is_src_valid); + } + else if constexpr(SrcBuffer::IsStaticBuffer()) + { + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_ref_to_origin_disp_idx + data_to_origin_disp_idx + + i * src_scalar_step_in_vector); + + src_tmp_vector.template AsType()(i) = src_buf[Number{}]; + }); + } + + if constexpr(is_same, pk_i4_t>::value) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + vector_type scale_vector; + scale_vector.template AsType()(Number<0>{}) = scale; + scale_vector.template AsType()(Number<1>{}) = scale; + + constexpr index_t pack_size = 8; + + static_assert(SrcScalarPerVector % pack_size == 0, ""); + + using src_v_t = typename vector_type_maker_t::type; + using dst_v_t = typename vector_type_maker_t::type; + using scale_v_t = typename vector_type_maker_t::type; + + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::DequantPack8{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i], + scale_vector.template AsType()[Number<0>{}]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else if constexpr(is_same, f8_t>::value && + is_same, half_t>::value && + SrcScalarPerVector % 2 == 0) { // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to // DstData) @@ -1304,7 +1544,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ElementwiseOperation element_op_; }; -// Specilized for WMMA-Navi3 +// Specialized for gfx11 // A single Wave32 is composed by double row // Data exchange allowed between these two rows // This RowLane Dst buf will be filled from two Src buf @@ -1439,7 +1679,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow ElementwiseOperation element_op_{}; }; -// Specilized for WMMA-Navi4 +// Specialized for gfx12 template {}; + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I10 = Number<10>{}; + static constexpr auto I12 = Number<12>{}; + static constexpr auto I13 = Number<13>{}; + static constexpr auto I14 = Number<14>{}; + static constexpr auto I16 = Number<16>{}; + + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr auto 
SrcScalarPerVector = Number{}; + static constexpr auto DstScalarPerVector = Number{}; __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( const SrcDesc& src_desc, @@ -67,6 +90,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_element_op_(src_element_op), dst_element_op_(dst_element_op) { + if constexpr(is_same_v, pk_i4_t>) + { + static_assert(is_same_v, remove_cvref_t>, + "SrcData != DstData"); + + static_assert( + SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, + "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); + + static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose"); + } } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -95,10 +129,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0, + + static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); constexpr auto src_dim_access_order = SrcDimAccessOrder{}; @@ -176,12 +211,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 src_oob_thread_scratch_tuple_(thread_scratch_id) .template SetAsType(src_data_idx_seq, is_src_valid); - using src_vector_type = vector_type_maker_t; - using src_vector_t = typename src_vector_type::type; - - auto src_vector_container = - src_vector_type{src_buf.template Get(src_coord_.GetOffset(), true)}; - using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; dst_vector_type op_r_v; @@ -192,17 +221,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1 if constexpr(decltype(src_element_op_)::is_pack8_invocable) return math::min(8, SrcScalarPerVector); } - if constexpr(is_detected::value) + else if constexpr(is_detected::value) { if constexpr(decltype(src_element_op_)::is_pack4_invocable) return math::min(4, SrcScalarPerVector); } - if constexpr(is_detected::value) + else if constexpr(is_detected::value) { if constexpr(decltype(src_element_op_)::is_pack2_invocable) return math::min(2, SrcScalarPerVector); } - return 1; + else + { + return 1; + } }; constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); @@ -210,11 +244,63 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using src_elem_op_vec_t = typename vector_type::type; using dst_elem_op_vec_t = typename vector_type::type; - static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) { - // apply the src elementwise op and convert to DstData under the hood if needed - src_element_op_(op_r_v.template AsType()(idx), - src_vector_container.template AsType()[idx]); - }); + using VectorSizeLookupTable = Tuple, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence>; + using VectorOffsetsLookupTable = Tuple, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence>; + + static_for<0, tuple_element_t::Size(), 1>{}( + [&](auto v_idx) { + constexpr auto VectorLoadSize = 
+ tuple_element_t::At(v_idx); + constexpr auto LoadOffset = + tuple_element_t::At(v_idx); + + using src_vector_container = vector_type_maker_t; + using src_vector_container_t = typename src_vector_container::type; + + src_vector_container src_vector = + src_vector_container{src_buf.template Get( + src_coord_.GetOffset() / PackedSize + LoadOffset, true)}; + + static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) { + // apply the src elementwise op and convert to DstData under the hood if + // needed + src_element_op_( + op_r_v.template AsType()(idx + LoadOffset), + src_vector.template AsType()[idx]); + }); + }); // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) @@ -289,10 +375,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); #else - // OOB Check constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -363,6 +448,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 (is_same>::value && SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) { + static_assert(!is_same_v, pk_i4_t>, + "in-register transpose is not supported for pk_i4_t"); // each transpose does // DstScalarPerVector # of src vectors in src_thread_scratch_ // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ @@ -423,7 +510,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 } else { - static_ford{}([&](auto idx) { + constexpr auto packed_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access; + + static_ford{}([&](auto idx) { dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); } @@ -451,7 +543,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // src scalar per access on each dim // TODO: don't use this constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -539,13 +631,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // apply DstElementwiseOperation dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); - - dst_vector_container.template AsType()(i) = dst_v; }); // copy data from dst_vector_container to dst_buf dst_buf.template Set( - dst_coord_.GetOffset(), + dst_coord_.GetOffset() / PackedSize, is_dst_valid, dst_vector_container.template AsType()[I0]); @@ -605,7 +695,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -663,7 +753,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -749,7 +839,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ static constexpr auto 
GetSrcThreadScratchDescriptor() { constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -798,7 +888,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() { constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -809,7 +899,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 { // 1st stage of transforms constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index b435a2a1293c87cd70bee4130e5a15f60bec6bda..1abae56be4d3625e5eaaf978a7350ca011adfe37 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -307,7 +307,7 @@ struct wmma_type{}; - // * Fixed in Navi3x, Will be wave mode dependent on Navi4x + // * Fixed for gfx11, Will be wave mode dependent on gfx12 // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4; // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4; // * num_acc_vgprs_per_wave alone M direction diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 24fac91e22a01c7fbf7744ab48e7256d7e2ef900..4f20487b9b7e531e6b284162541934bbebebcfac 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -37,7 +37,17 @@ enum struct MfmaInstr mfma_f32_32x32x16f8bf8, mfma_f32_16x16x32f8bf8, mfma_f32_32x32x16bf8f8, - mfma_f32_16x16x32bf8f8 + mfma_f32_16x16x32bf8f8, + mfma_f32_32x32x16f16, + mfma_f32_16x16x32f16, + mfma_f32_32x32x16bf16, + mfma_f32_16x16x32bf16, + mfma_i32_32x32x32i8, + mfma_i32_16x16x64i8, + mfma_f32_32x32x64f8f6f4, + mfma_f32_16x16x128f8f6f4, + mfma_scale_f32_32x32x64f8f6f4, + mfma_scale_f32_16x16x128f8f6f4 }; template @@ -198,6 +208,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks 
= 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -264,6 +318,28 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -286,6 +362,28 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -440,6 +538,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x32i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_16x16x64i8::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -638,16 +780,115 @@ struct mfma_type } }; +// TODO: fix mfma...f8f6f4 instructions +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size 
= 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 32; // n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x64f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 16; // == n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x128f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 32; // n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_scale_f32_32x32x64f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 1; // ??? 
group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 16; // == n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_scale_f32_16x16x128f8f6f4::Run(a, b, reg_c); + } +}; + template + typename additional_type = base_type, + bool is_single_rate_mfma = false> struct MfmaSelector { template + typename additional_type_ = base_type_, + bool is_single_rate_mfma_ = false> static constexpr auto GetMfma(); template <> @@ -711,13 +952,32 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16f16; +#else + return MfmaInstr::mfma_f32_32x32x8f16; +#endif + } + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x8f16; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32f16; +#else + return MfmaInstr::mfma_f32_16x16x16f16; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x16f16; } @@ -741,7 +1001,19 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16bf16; +#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_32x32x8bf16_1k; +#else + return MfmaInstr::mfma_f32_32x32x4bf16; +#endif + } + + template <> + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_32x32x8bf16_1k; @@ -751,7 +1023,19 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32bf16; +#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_16x16x16bf16_1k; +#else + return MfmaInstr::mfma_f32_16x16x8bf16; +#endif + } + + template <> + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_16x16x16bf16_1k; @@ -760,7 +1044,18 @@ struct MfmaSelector #endif } -#if defined(CK_USE_AMD_MFMA_GFX940) +#if defined(__gfx950__) + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_32x32x32i8; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_16x16x64i8; + } +#elif defined(__gfx942__) template <> constexpr auto GetMfma() { @@ -832,8 +1127,8 @@ struct MfmaSelector return MfmaInstr::mfma_f32_16x16x32bf8f8; } - static constexpr auto selected_mfma = - mfma_type()>{}; + static constexpr auto selected_mfma = mfma_type< + GetMfma()>{}; __host__ __device__ constexpr MfmaSelector() { @@ -1135,7 +1430,13 @@ struct XdlopsGemm return TransposeC ? 
CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; } - static constexpr auto mfma = MfmaSelector{}; + // Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942- + static constexpr auto + mfma = MfmaSelector < base_type, + MPerXdlops, NPerXdlops, additional_type, + ((is_same::value || is_same::value) && KPack <= 4) + ? true + : false > {}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index b91b12ad52380a82ac6213cea27016700aca1461..3db94deccb465da766483521bf7bef7ca332a02c 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1,10 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -148,8 +147,8 @@ struct TransformConvFwdToGemm template ::type = false> + index_t NDim = NDimSpatial, + typename ck::enable_if::type = false> __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& b_g_k_c_xs_lengths, @@ -201,11 +200,15 @@ struct TransformConvFwdToGemm InRightPadW_{input_right_pads[I0]}, ZYX_{X_} { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else static_assert(is_same_v> || is_same_v>); static_assert(is_same_v> || is_same_v>); - +#endif if constexpr(SplitN) { N_ = GetSplitedNSize( @@ -219,8 +222,8 @@ struct TransformConvFwdToGemm template ::type = false> + index_t NDim = NDimSpatial, + typename ck::enable_if::type = false> __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& b_g_k_c_xs_lengths, @@ -272,11 +275,15 @@ struct TransformConvFwdToGemm InRightPadW_{input_right_pads[I1]}, ZYX_{Y_ * X_} { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else static_assert(is_same_v> || is_same_v>); static_assert(is_same_v> || is_same_v>); - +#endif if constexpr(SplitN) { N_ = GetSplitedNSize( @@ -290,8 +297,8 @@ struct TransformConvFwdToGemm template ::type = false> + index_t NDim = NDimSpatial, + typename ck::enable_if::type = false> __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& b_g_k_c_xs_lengths, @@ -343,11 +350,15 @@ struct TransformConvFwdToGemm InRightPadW_{input_right_pads[I2]}, ZYX_{Z_ * Y_ * X_} { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else static_assert(is_same_v> || is_same_v>); static_assert(is_same_v> || is_same_v>); - +#endif if constexpr(SplitN) { N_ = GetSplitedNSize( @@ -478,11 +489,11 @@ struct TransformConvFwdToGemm // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as // properties template || - is_same_v || - is_same_v), - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto 
MakeADescriptor_M_K() const { if constexpr(ConvForwardSpecialization == @@ -691,11 +702,11 @@ struct TransformConvFwdToGemm } template || - is_same_v || - is_same_v), - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeADescriptor_M_K() const { @@ -932,7 +943,7 @@ struct TransformConvFwdToGemm } template || is_same_v || is_same_v), @@ -1242,19 +1253,19 @@ struct TransformConvFwdToGemm } template || - is_same_v || - is_same_v, - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v, + bool>::type = false> __host__ __device__ auto MakeBDescriptor_N_K() const { if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter3x3) { using FilterSizeNumType = - std::conditional_t, - std::conditional_t, Number<27>>>; + ck::conditional_t, + ck::conditional_t, Number<27>>>; if constexpr(NumGroupsToMerge == 1) { @@ -1297,13 +1308,13 @@ struct TransformConvFwdToGemm template < typename BLayout, - typename std::enable_if || - is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v, - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> __host__ __device__ auto MakeBDescriptor_N_K() const { const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( @@ -1318,36 +1329,36 @@ struct TransformConvFwdToGemm return wei_gemmn_gemmk_desc; } - template ), - bool>::type = false> + typename ck::enable_if), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_), make_tuple(I0, KStrideTensorC_)); } - template ), - bool>::type = false> + typename ck::enable_if), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), make_tuple(I0, KStrideTensorC_)); } - template ), - bool>::type = false> + typename ck::enable_if), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), @@ -1355,12 +1366,12 @@ struct TransformConvFwdToGemm } template || - is_same_v || - is_same_v), - bool>::type = false> + index_t NDimSp = NDimSpatial, + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { const IndexType NDoHoWo = N_ * Wo_; @@ -1410,11 +1421,11 @@ struct TransformConvFwdToGemm template || - is_same_v || - is_same_v), - bool>::type = false> + typename ck::enable_if || + is_same_v || + is_same_v), + bool>::type = false> __host__ __device__ auto MakeCDescriptor_M_N() const { const IndexType NDoHoWo = N_ * Ho_ * Wo_; @@ -1467,7 +1478,7 @@ struct TransformConvFwdToGemm template || is_same_v || is_same_v), diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 5367c3d72057848f86eb30b2f05415893d76780a..328e37d00971eee2ee50270a320673f0e1f88a9d 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "data_type.hpp" @@ -429,7 +429,8 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); using r_t = typename vector_type::type; @@ -580,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ tmp.template AsType()[i]); }); } -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) else if constexpr(is_same::value) { vector_type tmp{src_thread_data}; @@ -1020,15 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; static_assert(bytes_per_thread == dword_bytes); +#ifndef CK_CODE_GEN_RTC const uint32_t* global_ptr = reinterpret_cast(reinterpret_cast(global_base_ptr)); +#else + const uint32_t* global_ptr = + reinterpret_cast(reinterpret_cast(global_base_ptr)); +#endif const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM T* lds_ptr = lds_base_ptr + lds_offset; +#ifndef CK_CODE_GEN_RTC auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); +#else + auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); +#endif asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "v"(global_offset_bytes), @@ -1037,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, #else // LDS pointer must be attributed with the LDS address space. __attribute__((address_space(3))) uint32_t* lds_ptr = +#ifndef CK_CODE_GEN_RTC reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( reinterpret_cast(lds_base_ptr + lds_offset)); +#else + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( + reinterpret_cast(lds_base_ptr + lds_offset)); +#endif llvm_amdgcn_raw_buffer_load_lds( src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index e9174904c9fedbea99dfb43d6b25947d8ace9401..42b784d303766ccf3e3dd1ba0d7ee296f34f3d85 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -1,8 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once +#include "ck/ck.hpp" +#include "ck/utility/enable_if.hpp" #include "ck/utility/random_gen.hpp" #include "ck/utility/type.hpp" @@ -18,39 +20,25 @@ #define CK_USE_OCP_FP8 0 #endif -namespace { -// https://en.cppreference.com/w/cpp/types/conditional -template -struct conditional -{ - using type = T; -}; -template -struct conditional -{ - using type = F; -}; -} // namespace - -namespace ck { - -using f8_fnuz_t = _BitInt(8); -using bf8_fnuz_t = unsigned _BitInt(8); - #if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \ - defined(__gfx1201__)) && \ + defined(__gfx1201__) || defined(__gfx950__)) && \ __HIP_DEVICE_COMPILE__ #define CK_FP8_CVT_FAST_PATH 1 #else #define CK_FP8_CVT_FAST_PATH 0 #endif -#if(defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__ +#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__ #define CK_OCP_FP8_CVT_FAST_PATH 1 #else #define CK_OCP_FP8_CVT_FAST_PATH 0 #endif +namespace ck { + +using f8_fnuz_t = _BitInt(8); +using bf8_fnuz_t = unsigned _BitInt(8); + typedef unsigned char fp8_storage_t; /** @@ -205,10 +193,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) } } - typename conditional< + typename std::conditional< sizeof(T) == 2, unsigned short int, - typename conditional::type>::type retval; + typename std::conditional::type>::type + retval; if constexpr(we == 5 && is_half && !is_fnuz) { @@ -301,7 +290,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v) return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false); } } - #endif } // namespace fp8_impl @@ -376,7 +364,7 @@ struct bf8_ocp_t __host__ explicit operator float() const #endif { -#if defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) return fp8_impl::cast_to_f32_from_f8(this->data); #else return fp8_impl::cast_from_f8( @@ -390,7 +378,7 @@ struct bf8_ocp_t __host__ explicit operator _Float16() const #endif { -#if defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data)); #else return fp8_impl::cast_from_f8<_Float16, wm, we, false>( @@ -424,9 +412,9 @@ __host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a) } template || std::is_same_v || - std::is_same_v || std::is_same_v, - bool> = true> + ck::enable_if_t || is_same_v || + is_same_v || is_same_v, + bool> = true> __host__ __device__ static inline constexpr bool fp8_is_inf(T) { return false; @@ -551,10 +539,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 
23 : 10); - using T_bitwise = typename conditional< + using T_bitwise = typename std::conditional< sizeof(T) == 2, unsigned short int, - typename conditional::type>::type; + typename std::conditional::type>::type; T_bitwise x_bitwise = bit_cast(_x); unsigned long long x{x_bitwise}; @@ -823,7 +811,11 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) if constexpr(stochastic_rounding) { constexpr int seed = 1254739; - rng = prand_generator(reinterpret_cast(&f), f); +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&f), f); +#else + rng = prand_generator(reinterpret_cast(&f), f); +#endif } return cast_to_f8_from_f32( f, rng); @@ -839,7 +831,11 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) if constexpr(stochastic_rounding) { constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC rng = prand_generator(reinterpret_cast(&f), f); +#else + rng = prand_generator(reinterpret_cast(&f), f); +#endif } if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ) diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 5dc67a5aded4af289d0240394f720af62e699eb4..113f3af4ae51adbb17512f10a6cdec55d535c40a 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -4,13 +4,34 @@ #ifndef CK_AMD_INLINE_ASM_HPP #define CK_AMD_INLINE_ASM_HPP -#include "data_type.hpp" #include "c_style_pointer_cast.hpp" +#include "data_type.hpp" // TODO: deprecate all amd_assembly_outer_product_xxx namespace ck { +inline __device__ int amd_assembly_and_or_b32(int a, int b, int d) +{ + int c; + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d)); + return c; +} + +inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c) +{ + half2_t d; + asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + return d; +} + +inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b) +{ + half2_t c; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b)); + return c; +} + // c0 += inner_product(a, b0) // c1 += inner_product(a, b1) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index d6e1eab314e30184c669abe88f5a4cf7f5ea90c4..128c8e9a2c50ba9dc9d123c1e3dc4c036f39e872 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -7,10 +7,12 @@ #include "ck/utility/functional2.hpp" #include "ck/utility/math.hpp" +#ifndef CK_CODE_GEN_RTC #include #include #include #include +#endif namespace ck { namespace detail { @@ -37,7 +39,7 @@ struct get_carrier<3> { using value_type = uint32_t; - std::array bytes; + Array bytes; static_assert(sizeof(bytes) <= sizeof(value_type)); // replacement of host std::copy_n() @@ -61,22 +63,22 @@ struct get_carrier<3> // method to trigger template substitution failure __device__ carrier(const carrier& other) noexcept { - copy_n(other.bytes.begin(), bytes.size(), bytes.begin()); + copy_n(other.bytes.begin(), bytes.Size(), bytes.begin()); } public: __device__ carrier& operator=(value_type value) noexcept { - copy_n(reinterpret_cast(&value), bytes.size(), bytes.begin()); + copy_n(reinterpret_cast(&value), bytes.Size(), bytes.begin()); return *this; } __device__ operator value_type() const noexcept { - std::byte result[sizeof(value_type)]; + ck::byte result[sizeof(value_type)]; - copy_n(bytes.begin(), bytes.size(), result); + copy_n(bytes.begin(), bytes.Size(), result); return *reinterpret_cast(result); } @@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value) { constexpr unsigned object_size = sizeof(int64_t); constexpr unsigned second_part_offset = object_size / 2; - auto* const from_obj = reinterpret_cast(&value); - alignas(int64_t) std::byte to_obj[object_size]; + auto* const from_obj = reinterpret_cast(&value); + alignas(int64_t) ck::byte to_obj[object_size]; using Sgpr = uint32_t; @@ -122,17 +124,16 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value) return *reinterpret_cast(to_obj); } -template < - typename Object, - typename = std::enable_if_t && std::is_trivially_copyable_v>> +template && ck::is_trivially_copyable_v>> __device__ auto amd_wave_read_first_lane(const Object& obj) { using Size = unsigned; constexpr Size SgprSize = 4; constexpr Size ObjectSize = sizeof(Object); - auto* const from_obj = reinterpret_cast(&obj); - alignas(Object) std::byte to_obj[ObjectSize]; + auto* const from_obj = reinterpret_cast(&obj); + alignas(Object) ck::byte to_obj[ObjectSize]; constexpr Size RemainedSize = ObjectSize % SgprSize; constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 5a7030cca766ed4bfdea91fb5f859211fd0d12ce..b125e3adf63a5b50a63a3f8f62c2417eacc8b2dd 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -5,7 +5,7 @@ namespace ck { // Define the common macro for MI300 models -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__) #define __gfx94__ #endif @@ -134,6 +134,46 @@ struct intrin_mfma_f32_32x32x4f16<32, 64> } }; +template +struct intrin_mfma_f32_32x32x16f16; + +template <> +struct intrin_mfma_f32_32x32x16f16<32, 32> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_16x16x32f16; + +template <> +struct intrin_mfma_f32_16x16x32f16<16, 16> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& 
reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + template struct intrin_mfma_f32_32x32x8f16; @@ -204,6 +244,46 @@ struct intrin_mfma_f32_4x4x4f16<8, 64> }; // bfp16 +template +struct intrin_mfma_f32_32x32x16bf16; + +template <> +struct intrin_mfma_f32_32x32x16bf16<32, 32> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_16x16x32bf16; + +template <> +struct intrin_mfma_f32_16x16x32bf16<16, 16> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + template struct intrin_mfma_f32_32x32x8bf16_1k; @@ -298,6 +378,46 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> } }; +template +struct intrin_mfma_i32_32x32x32i8; + +template <> +struct intrin_mfma_i32_32x32x32i8<32, 32> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_i32_16x16x64i8; + +template <> +struct intrin_mfma_i32_16x16x64i8<16, 16> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + template struct intrin_mfma_i32_32x32x16i8; @@ -356,6 +476,149 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> } }; +template +struct intrin_mfma_f32_32x32x64f8f6f4; + +/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6, +/// and f4 data types. +/// +/// @note Calls scaled version of the instruction as the original instruction is not supported in +/// the backend. That is the intended use. There is a backend optimization to select the unscaled +/// operation if the scale is 0. 
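+/// @note For intuition only: a minimal reference of the accumulation a single 32x32x64 tile
+/// performs. This plain-float loop is an assumption for illustration; it ignores the f8/f6/f4
+/// encodings, the per-lane register layout, and the cbsz/blgp block controls.
+/// @code{.cpp}
+/// // C (32x32, f32) += A (32x64) * B (64x32), with A and B logically dequantized to float.
+/// inline void mfma_32x32x64_reference(const float A[32][64], const float B[64][32], float C[32][32])
+/// {
+///     for(int m = 0; m < 32; ++m)
+///         for(int n = 0; n < 32; ++n)
+///             for(int k = 0; k < 64; ++k)
+///                 C[m][n] += A[m][k] * B[k][n];
+/// }
+/// @endcode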
+template <> +struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> +{ + template + __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_scale_f32_32x32x64f8f6f4; + +template <> +struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> +{ + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t scale_a, + const f8x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, // { OPSEL_HI[0], OPSEL[0] }? + scale_a, + 0, // { OPSEL_HI[1], OPSEL[1] }? + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_scale_f32_16x16x128f8f6f4; + +template <> +struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> +{ + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t scale_a, + const f8x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, // { OPSEL_HI[0], OPSEL[0] }? + scale_a, + 0, // { OPSEL_HI[1], OPSEL[1] }? + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x128f8f6f4; + +/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4 +/// data types. +/// +/// @note Calls scaled version of the instruction as the original instruction is not supported in +/// the backend. That is the intended use. There is a backend optimization to select the unscaled +/// operation if the scale is 0. +template <> +struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> +{ + template + __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + template struct intrin_mfma_f32_32x32x16f8f8; diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 5366c56a9dfa7275ecca75d41daaf1a5cba6333d..2afad00d497f840af8221ef66ae8ec24de7e23ec 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP @@ -38,6 +38,8 @@ struct Array } __host__ __device__ constexpr const TData* begin() const { return &mData[0]; } __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; } + __host__ __device__ constexpr TData* begin() { return &mData[0]; } + __host__ __device__ constexpr TData* end() { return &mData[NSize]; } }; // empty Array @@ -54,7 +56,7 @@ template __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) { using data_type = remove_cvref_t; - return Array{std::forward(x), std::forward(xs)...}; + return Array{ck::forward(x), ck::forward(xs)...}; } // make empty array diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp index 96dd34010e03a6165f6f0a792ab7c19b7c7eec97..574be5c4adaf918d084f936053ec8649f77229c1 100644 --- a/include/ck/utility/blkgemmpipe_scheduler.hpp +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -103,14 +103,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst KPerXDL); printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " - "%d, %d\n C MFMA inst: %d\n", + "%d, %d\n C MFMA inst: %d\n" + "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: " + "%d/ %d\n", A_Buffer_Load_Inst_Num, B_Buffer_Load_Inst_Num, A_LDS_Write_Inst_Num, B_LDS_Write_Inst_Num, A_LDS_Read_Inst_Num, B_LDS_Read_Inst_Num, - C_MFMA_Inst_Num); + C_MFMA_Inst_Num, + A_LDS_Read_Width, + B_LDS_Read_Width, + ALDSWriteWidth, + BLDSWriteWidth, + ABufferLoadWidth, + BBufferLoadWidth); } }; diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index 9c7b954565d386a8fdecd21052b102e750ab7102..bd0ca42ecddb736aa002151969bf83bf093691ff 100644 --- a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP @@ -326,14 +326,14 @@ template __host__ __device__ constexpr auto container_concat(const Array& ax, const Array& ay) { return unpack2( - [&](auto&&... zs) { return make_array(std::forward(zs)...); }, ax, ay); + [&](auto&&... zs) { return make_array(ck::forward(zs)...); }, ax, ay); } template __host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) { return unpack2( - [&](auto&&... zs) { return make_tuple(std::forward(zs)...); }, tx, ty); + [&](auto&&... zs) { return make_tuple(ck::forward(zs)...); }, tx, ty); } template diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index a7dc071bc21e85b7cd732dc08db5bd9ca869394b..f90fcf67915e6b75ae31b3d10a38b8f8dac23164 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,16 +1,328 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/e8m0.hpp" #include "ck/utility/statically_indexed_array.hpp" - +#ifdef CK_CODE_GEN_RTC +using int8_t = signed char; +using uint8_t = unsigned char; +using int16_t = signed short; +using uint16_t = unsigned short; +using float_t = float; +#endif namespace ck { +#ifdef CK_CODE_GEN_RTC +using byte = unsigned char; +#else +using std::byte; +#endif + using bhalf_t = ushort; using half_t = _Float16; using int4_t = _BitInt(4); +using f4_t = unsigned _BitInt(4); +using f6_t = _BitInt(6); // e2m3 format +using bf6_t = unsigned _BitInt(6); // e3m2 format + +struct f4x2_pk_t +{ + using type = uint8_t; + type data; + f4x2_pk_t() : data{type{}} {} + f4x2_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline type unpack(Number) const + { + static_assert(I < 2, "Index is out of range."); + if constexpr(I == 0) + return data & 0b00001111; + else + return (data >> 4); + } + + __host__ __device__ inline type pack(const type x0, const type x1) + { + return (x1 << 4) | (x0 & 0b00001111); + } +}; + +struct f6x16_pk_t +{ + // store 16 elements of f6_t in an array of 3 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + f6x16_pk_t() : data{type{}} {} + f6x16_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline f6_t unpack(Number) + { + static_assert(I < 16, "Index out of range for 16 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 16 f6_t values, place its 6 bits in the correct position + ck::static_for<0, 16, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +struct f6x32_pk_t +{ + // store 32 elements of f6_t in an array of 6 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); + f6x32_pk_t() : data{type{}} {} + f6x32_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline f6_t 
unpack(Number) + { + static_assert(I < 32, "Index out of range for 32 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 32 f6_t values, place its 6 bits in the correct position + ck::static_for<0, 32, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +struct bf6x16_pk_t +{ + // store 16 elements of bf6_t in an array of 3 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + bf6x16_pk_t() : data{type{}} {} + bf6x16_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline bf6_t unpack(Number) + { + static_assert(I < 16, "Index out of range for 16 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 16 bf6_t values, place its 6 bits in the correct position + ck::static_for<0, 16, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = 
old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +struct bf6x32_pk_t +{ + // store 32 elements of bf6_t in an array of 6 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); + bf6x32_pk_t() : data{type{}} {} + bf6x32_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline bf6_t unpack(Number) + { + static_assert(I < 32, "Index out of range for 32 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 32 bf6_t values, place its 6 bits in the correct position + ck::static_for<0, 32, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +// custom data type - pack int4 data +struct pk_i4_t +{ + using type = int8_t; + type data; + __host__ __device__ constexpr pk_i4_t() : data{type{}} {} + __host__ __device__ constexpr pk_i4_t(type init) : data{init} {} +}; inline constexpr auto next_pow2(uint32_t x) { @@ -19,14 +331,15 @@ inline constexpr auto next_pow2(uint32_t x) } // native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, -// native types: bool +// native types: bool, f4_t, f6_t, bf6_t template inline constexpr bool is_native_type() { return is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value; + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value; } // vector_type @@ -165,6 +478,13 @@ struct scalar_type }; #endif +template <> +struct scalar_type +{ + using type = pk_i4_t; + static constexpr index_t vector_size = 1; +}; + template <> struct scalar_type { @@ -201,7 +521,7 @@ struct scalar_type }; template -struct vector_type()>> +struct 
vector_type()>> { using d1_t = T; using type = d1_t; @@ -237,7 +557,7 @@ struct vector_type()>> __device__ int static err = 0; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -297,20 +617,20 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d3_t __attribute__((ext_vector_type(3))); - using type = d4_t; + using type = d3_t; union { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; + d3_t d3_; + StaticallyIndexedArray d1x3_; + StaticallyIndexedArray d2x1_; + StaticallyIndexedArray d3x1_; } data_; __host__ __device__ constexpr vector_type() : data_{type{0}} {} @@ -320,20 +640,20 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x4_; + return data_.d1x3_; } else if constexpr(is_same::value) { - return data_.d2x2_; + return data_.d2x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d4x1_; + return data_.d3x1_; } else { @@ -344,20 +664,20 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x4_; + return data_.d1x3_; } else if constexpr(is_same::value) { - return data_.d2x2_; + return data_.d2x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d4x1_; + return data_.d3x1_; } else { @@ -367,22 +687,20 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - using type = d8_t; + using type = d4_t; union { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; } data_; __host__ __device__ constexpr vector_type() : data_{type{0}} {} @@ -392,25 +710,20 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, + static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x8_; + return data_.d1x4_; } else if constexpr(is_same::value) { - return data_.d2x4_; + return data_.d2x2_; } else if constexpr(is_same::value) { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; + return data_.d4x1_; } else { @@ -421,25 +734,20 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, + 
static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x8_; + return data_.d1x4_; } else if constexpr(is_same::value) { - return data_.d2x4_; + return data_.d2x2_; } else if constexpr(is_same::value) { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; + return data_.d4x1_; } else { @@ -449,24 +757,20 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d5_t __attribute__((ext_vector_type(5))); - using type = d16_t; + using type = d5_t; union { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; + d5_t d5_; + StaticallyIndexedArray d1x5_; + StaticallyIndexedArray d4x1_; + StaticallyIndexedArray d5x1_; } data_; __host__ __device__ constexpr vector_type() : data_{type{0}} {} @@ -476,30 +780,20 @@ struct vector_type()>> template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, + static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; + return data_.d1x5_; } else if constexpr(is_same::value) { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; + return data_.d4x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d16x1_; + return data_.d5x1_; } else { @@ -510,30 +804,20 @@ struct vector_type()>> template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, + static_assert(is_same::value || is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; + return data_.d1x5_; } else if constexpr(is_same::value) { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; + return data_.d4x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d16x1_; + return data_.d5x1_; } else { @@ -543,26 +827,22 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d7_t __attribute__((ext_vector_type(7))); - using type = d32_t; + using type = d7_t; union { - d32_t d32_; - StaticallyIndexedArray d1x32_; - StaticallyIndexedArray d2x16_; - StaticallyIndexedArray d4x8_; - StaticallyIndexedArray d8x4_; - StaticallyIndexedArray d16x2_; - StaticallyIndexedArray d32x1_; + d7_t d7_; + StaticallyIndexedArray d1x7_; + StaticallyIndexedArray d2x3_; + 
StaticallyIndexedArray d4x1_; + StaticallyIndexedArray d7x1_; } data_; __host__ __device__ constexpr vector_type() : data_{type{0}} {} @@ -573,33 +853,24 @@ struct vector_type()>> __host__ __device__ constexpr const auto& AsType() const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x32_; + return data_.d1x7_; } else if constexpr(is_same::value) { - return data_.d2x16_; + return data_.d2x3_; } else if constexpr(is_same::value) { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; + return data_.d4x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d32x1_; + return data_.d7x1_; } else { @@ -611,33 +882,24 @@ struct vector_type()>> __host__ __device__ constexpr auto& AsType() { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x32_; + return data_.d1x7_; } else if constexpr(is_same::value) { - return data_.d2x16_; + return data_.d2x3_; } else if constexpr(is_same::value) { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; + return data_.d4x1_; } - else if constexpr(is_same::value) + else if constexpr(is_same::value) { - return data_.d32x1_; + return data_.d7x1_; } else { @@ -647,28 +909,22 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - using type = d64_t; + using type = d8_t; union { - d64_t d64_; - StaticallyIndexedArray d1x64_; - StaticallyIndexedArray d2x32_; - StaticallyIndexedArray d4x16_; - StaticallyIndexedArray d8x8_; - StaticallyIndexedArray d16x4_; - StaticallyIndexedArray d32x2_; - StaticallyIndexedArray d64x1_; + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; } data_; __host__ __device__ constexpr vector_type() : data_{type{0}} {} @@ -679,34 +935,402 @@ struct vector_type()>> __host__ __device__ constexpr const auto& AsType() const { static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { - return data_.d1x64_; + return data_.d1x8_; } else if constexpr(is_same::value) { - return data_.d2x32_; + return data_.d2x4_; } else if constexpr(is_same::value) { - return data_.d4x16_; + return data_.d4x2_; } else if constexpr(is_same::value) { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; + return data_.d8x1_; } - else if constexpr(is_same::value) + else { - return data_.d32x2_; + return err; + } + 
} + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d13_t __attribute__((ext_vector_type(13))); + + using type = d13_t; + + union + { + d13_t d13_; + StaticallyIndexedArray d1x13_; + StaticallyIndexedArray d4x3_; + StaticallyIndexedArray d8x1_; + StaticallyIndexedArray d13x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x13_; + } + else if constexpr(is_same::value) + { + return data_.d4x3_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else if constexpr(is_same::value) + { + return data_.d13x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x13_; + } + else if constexpr(is_same::value) + { + return data_.d4x3_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else if constexpr(is_same::value) + { + return data_.d13x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + + using type = d16_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and 
dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + + using type = d32_t; + + union + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please 
check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; } else if constexpr(is_same::value) { @@ -763,7 +1387,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -889,7 +1513,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -1038,17 +1662,48 @@ struct nnvb_data_t_selector { using type = f8_ocp_t::data_type; }; + template <> struct nnvb_data_t_selector { using type = bf8_ocp_t::data_type; }; +template <> +struct nnvb_data_t_selector +{ + using type = f6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f6x32_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x32_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = pk_i4_t::type; +}; + template struct non_native_vector_base< T, N, - std::enable_if_t> + ck::enable_if_t> { using data_t = typename nnvb_data_t_selector::type; // select data_t based on the size of T static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); @@ -1119,27 +1774,84 @@ struct non_native_vector_base< } } - template - __host__ __device__ constexpr auto& AsType() + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else if constexpr(is_same_v) + { + return data_.dNx1; + } + else + { + return err; + } + } +}; + +// implementation for f6x16 and f6x32 +template +struct non_native_vector_base> +{ + using data_t = + typename nnvb_data_t_selector::type; // select data_t based on declared base type + using element_t = typename T::element_type; // select element_t based on declared element type + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + static constexpr size_t size_factor = + sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6 + using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + __host__ __device__ constexpr non_native_vector_base(data_t a) + : data_{data_v(a.At(Number<0>{}))} { - static_assert(is_same_v || is_same_v || is_same_v, - "Something went wrong, please check src and dst types."); + } + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + { + } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} - if constexpr(is_same_v) + __host__ __device__ constexpr operator data_v() const { 
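    // Worked example of the size arithmetic behind this wrapper (assuming a
    // 4-byte element word, per the size_factor comment above): f6x16_pk_t packs
    // 16 six-bit values into 12 bytes, so each T contributes 12 / 4 = 3 lanes and
    // the backing ext_vector holds N * 3 words; f6x32_pk_t packs 32 values into
    // 24 bytes, i.e. 24 / 4 = 6 lanes per T. The raw-lane view (dN) is what this
    // conversion returns.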
return data_.dN; } + __host__ __device__ constexpr operator data_t() const + { + if constexpr(N == 1) { - return data_.dxN; + return data_.dxN[Number<0>{}]; } - else if constexpr(is_same_v) + else { - return data_.dTxN; + return data_.dxN; // XXX this should cause an error } - else if constexpr(is_same_v) + } + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) { - return data_.dNx1; + return data_.dTxN[Number<0>{}]; } else { - return err; + return data_.dTxN; // XXX this should cause an error } } }; @@ -1163,9 +1875,17 @@ struct scalar_type> static constexpr index_t vector_size = N; }; +template +struct scalar_type> +{ + using type = typename non_native_vector_base::data_t; + + static constexpr index_t vector_size = N; +}; + // non-native vector_type implementation template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d1_nnv_t = non_native_vector_base; @@ -1216,7 +1936,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d1_nnv_t = non_native_vector_base; @@ -1279,7 +1999,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d1_nnv_t = non_native_vector_base; @@ -1352,7 +2072,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d1_nnv_t = non_native_vector_base; @@ -1437,7 +2157,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d1_nnv_t = non_native_vector_base; @@ -1532,7 +2252,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = non_native_vector_base; @@ -1636,7 +2356,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = non_native_vector_base; @@ -1751,140 +2471,371 @@ struct vector_type()>> } }; -using int64_t = long; +using int64_t = long; + +// fp64 +using double2_t = typename vector_type::type; +using double4_t = typename vector_type::type; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; +using float64_t = typename vector_type::type; + +// fp16 +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; +using half64_t = typename vector_type::type; + +// bfp16 +using bhalf2_t = typename vector_type::type; +using bhalf4_t = typename vector_type::type; +using bhalf8_t = typename vector_type::type; +using bhalf16_t = typename vector_type::type; +using bhalf32_t = typename vector_type::type; +using bhalf64_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = typename vector_type::type; +using int32x64_t = typename vector_type::type; + +// i8 +using int8x2_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; + +// 
f8 +using f8x2_fnuz_t = typename vector_type::type; +using f8x4_fnuz_t = typename vector_type::type; +using f8x8_fnuz_t = typename vector_type::type; +using f8x16_fnuz_t = typename vector_type::type; +using f8x32_fnuz_t = typename vector_type::type; +using f8x64_fnuz_t = typename vector_type::type; + +// bf8 +using bf8x2_fnuz_t = typename vector_type::type; +using bf8x4_fnuz_t = typename vector_type::type; +using bf8x8_fnuz_t = typename vector_type::type; +using bf8x16_fnuz_t = typename vector_type::type; +using bf8x32_fnuz_t = typename vector_type::type; +using bf8x64_fnuz_t = typename vector_type::type; + +// f8 +using f8x2_ocp_t = typename vector_type::type; +using f8x4_ocp_t = typename vector_type::type; +using f8x8_ocp_t = typename vector_type::type; +using f8x16_ocp_t = typename vector_type::type; +using f8x32_ocp_t = typename vector_type::type; +using f8x64_ocp_t = typename vector_type::type; + +// bf8 +using bf8x2_ocp_t = typename vector_type::type; +using bf8x4_ocp_t = typename vector_type::type; +using bf8x8_ocp_t = typename vector_type::type; +using bf8x16_ocp_t = typename vector_type::type; +using bf8x32_ocp_t = typename vector_type::type; +using bf8x64_ocp_t = typename vector_type::type; + +#if CK_FP8_TYPE_OCP +// f8 +using f8x2_t = f8x2_ocp_t; +using f8x4_t = f8x4_ocp_t; +using f8x8_t = f8x8_ocp_t; +using f8x16_t = f8x16_ocp_t; +using f8x32_t = f8x32_ocp_t; +using f8x64_t = f8x64_ocp_t; + +// bf8 +using bf8x2_t = bf8x2_ocp_t; +using bf8x4_t = bf8x4_ocp_t; +using bf8x8_t = bf8x8_ocp_t; +using bf8x16_t = bf8x16_ocp_t; +using bf8x32_t = bf8x32_ocp_t; +using bf8x64_t = bf8x64_ocp_t; +#elif CK_FP8_TYPE_FNUZ +// f8 +using f8x2_t = f8x2_fnuz_t; +using f8x4_t = f8x4_fnuz_t; +using f8x8_t = f8x8_fnuz_t; +using f8x16_t = f8x16_fnuz_t; +using f8x32_t = f8x32_fnuz_t; +using f8x64_t = f8x64_fnuz_t; + +// bf8 +using bf8x2_t = bf8x2_fnuz_t; +using bf8x4_t = bf8x4_fnuz_t; +using bf8x8_t = bf8x8_fnuz_t; +using bf8x16_t = bf8x16_fnuz_t; +using bf8x32_t = bf8x32_fnuz_t; +using bf8x64_t = bf8x64_fnuz_t; +#endif + +// u8 +using uint8x2_t = typename vector_type::type; +using uint8x4_t = typename vector_type::type; +using uint8x8_t = typename vector_type::type; +using uint8x16_t = typename vector_type::type; +using uint8x32_t = typename vector_type::type; +using uint8x64_t = typename vector_type::type; + +// f4 +using f4x2_t = typename vector_type::type; +using f4x4_t = typename vector_type::type; +using f4x8_t = typename vector_type::type; +using f4x16_t = typename vector_type::type; +using f4x32_t = typename vector_type::type; +using f4x64_t = typename vector_type::type; + +// f6 +using f6x16_t = typename vector_type::type; +using f6x32_t = typename vector_type::type; + +// bf6 +using bf6x16_t = typename vector_type::type; +using bf6x32_t = typename vector_type::type; + +// pack int4 +using pk_i4x2_t = typename vector_type::type; +using pk_i4x4_t = typename vector_type::type; +using pk_i4x8_t = typename vector_type::type; + +#ifdef CK_CODE_GEN_RTC +template +struct NumericLimits; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; } + + __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int32_t QuietNaN() { return 0; } +}; +template <> +struct NumericLimits +{ + __host__ 
__device__ static constexpr int16_t Lowest() noexcept { return -32768; }
+
+    __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
+
+    __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
+
+    __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
+
+    __host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<int8_t>
+{
+    __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
+
+    __host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
+
+    __host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
+
+    __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
+
+    __host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<uint32_t>
+{
+    __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
+
+    __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
+
+    __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
+
+    __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
+
+    __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<uint16_t>
+{
+    __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
+
+    __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
+
+    __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
+
+    __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
+
+    __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
+};
+
+template <>
+struct NumericLimits<float>
+{
+    static constexpr unsigned int binary_min    = 0x00800000;
+    static constexpr unsigned int binary_max    = 0x7F7FFFFF;
+    static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
+    static constexpr unsigned int binary_qnan   = 0xFFC00001;
+    static constexpr unsigned int binary_inf    = 0x7F800000;
+
+    __host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
+
+    __host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
+
+    __host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
+
+    __host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
+
+    __host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
+};
+
+template <>
+struct NumericLimits<half_t>
+{
+    static constexpr unsigned short binary_min    = 0x0400;
+    static constexpr unsigned short binary_max    = 0x7BFF;
+    static constexpr unsigned short binary_lowest = 0xFBFF;
+    static constexpr unsigned short binary_qnan   = 0x7FFF;
+
+    __host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
+
+    __host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
+
+    __host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
+
+    __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
+};
+
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+template <>
+struct NumericLimits<int4_t>
+{
+    __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
+
+    __host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
+
+    __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
+};
+#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+
+template <>
+struct NumericLimits<f8_fnuz_t>
+{ + // negative zero nan mode with exp bias = 8 + static constexpr uint8_t binary_min = 0x08; // 0b00001000 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 7 + // static constexpr uint8_t binary_min = 0x08; // 0b00001000 + // static constexpr uint8_t binary_max = 0x77; // 0b01110111 + // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 + + __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } -// fp64 -using double2_t = typename vector_type::type; -using double4_t = typename vector_type::type; + __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } -// fp32 -using float2_t = typename vector_type::type; -using float4_t = typename vector_type::type; -using float8_t = typename vector_type::type; -using float16_t = typename vector_type::type; -using float32_t = typename vector_type::type; -using float64_t = typename vector_type::type; + __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } -// fp16 -using half2_t = typename vector_type::type; -using half4_t = typename vector_type::type; -using half8_t = typename vector_type::type; -using half16_t = typename vector_type::type; -using half32_t = typename vector_type::type; -using half64_t = typename vector_type::type; + __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } +}; -// bfp16 -using bhalf2_t = typename vector_type::type; -using bhalf4_t = typename vector_type::type; -using bhalf8_t = typename vector_type::type; -using bhalf16_t = typename vector_type::type; -using bhalf32_t = typename vector_type::type; -using bhalf64_t = typename vector_type::type; +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 16 + static constexpr uint8_t binary_min = 0x04; // 0b00000100 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 15 + // static constexpr uint8_t binary_min = 0x04; // 0b00000100 + // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 + // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= -// i32 -using int32x2_t = typename vector_type::type; -using int32x4_t = typename vector_type::type; -using int32x8_t = typename vector_type::type; -using int32x16_t = typename vector_type::type; -using int32x32_t = typename vector_type::type; -using int32x64_t = typename vector_type::type; + __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } -// i8 -using int8x2_t = typename vector_type::type; -using int8x4_t = typename vector_type::type; -using int8x8_t = typename vector_type::type; -using int8x16_t = typename vector_type::type; -using int8x32_t = typename vector_type::type; -using int8x64_t = typename vector_type::type; + __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } -// f8 -using f8x2_fnuz_t = typename vector_type::type; -using f8x4_fnuz_t = typename vector_type::type; -using f8x8_fnuz_t = typename vector_type::type; -using f8x16_fnuz_t = typename vector_type::type; -using f8x32_fnuz_t 
= typename vector_type::type; -using f8x64_fnuz_t = typename vector_type::type; + __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } -// bf8 -using bf8x2_fnuz_t = typename vector_type::type; -using bf8x4_fnuz_t = typename vector_type::type; -using bf8x8_fnuz_t = typename vector_type::type; -using bf8x16_fnuz_t = typename vector_type::type; -using bf8x32_fnuz_t = typename vector_type::type; -using bf8x64_fnuz_t = typename vector_type::type; + __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } +}; -// f8 -using f8x2_ocp_t = typename vector_type::type; -using f8x4_ocp_t = typename vector_type::type; -using f8x8_ocp_t = typename vector_type::type; -using f8x16_ocp_t = typename vector_type::type; -using f8x32_ocp_t = typename vector_type::type; -using f8x64_ocp_t = typename vector_type::type; +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 + static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 + static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 + static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 -// bf8 -using bf8x2_ocp_t = typename vector_type::type; -using bf8x4_ocp_t = typename vector_type::type; -using bf8x8_ocp_t = typename vector_type::type; -using bf8x16_ocp_t = typename vector_type::type; -using bf8x32_ocp_t = typename vector_type::type; -using bf8x64_ocp_t = typename vector_type::type; + __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } -#if CK_FP8_TYPE_OCP -// f8 -using f8x2_t = f8x2_ocp_t; -using f8x4_t = f8x4_ocp_t; -using f8x8_t = f8x8_ocp_t; -using f8x16_t = f8x16_ocp_t; -using f8x32_t = f8x32_ocp_t; -using f8x64_t = f8x64_ocp_t; + __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } -// bf8 -using bf8x2_t = bf8x2_ocp_t; -using bf8x4_t = bf8x4_ocp_t; -using bf8x8_t = bf8x8_ocp_t; -using bf8x16_t = bf8x16_ocp_t; -using bf8x32_t = bf8x32_ocp_t; -using bf8x64_t = bf8x64_ocp_t; -#elif CK_FP8_TYPE_FNUZ -// f8 -using f8x2_t = f8x2_fnuz_t; -using f8x4_t = f8x4_fnuz_t; -using f8x8_t = f8x8_fnuz_t; -using f8x16_t = f8x16_fnuz_t; -using f8x32_t = f8x32_fnuz_t; -using f8x64_t = f8x64_fnuz_t; + __host__ __device__ static constexpr f8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } -// bf8 -using bf8x2_t = bf8x2_fnuz_t; -using bf8x4_t = bf8x4_fnuz_t; -using bf8x8_t = bf8x8_fnuz_t; -using bf8x16_t = bf8x16_fnuz_t; -using bf8x32_t = bf8x32_fnuz_t; -using bf8x64_t = bf8x64_fnuz_t; -#endif + __host__ __device__ static constexpr f8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; -// u8 -using uint8x2_t = typename vector_type::type; -using uint8x4_t = typename vector_type::type; -using uint8x8_t = typename vector_type::type; -using uint8x16_t = typename vector_type::type; -using uint8x32_t = typename vector_type::type; -using uint8x64_t = typename vector_type::type; +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 + static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 + static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 + static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 + + __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } + __host__ __device__ static constexpr bf8_ocp_t Lowest() + { + return 
bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr bf8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; +#else template struct NumericLimits { __host__ __device__ static constexpr T Min() { return std::numeric_limits::min(); } - __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } - __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } - __host__ __device__ static constexpr T QuietNaN() { return std::numeric_limits::quiet_NaN(); } - __host__ __device__ static constexpr T Infinity() { return std::numeric_limits::infinity(); } }; @@ -2008,6 +2959,119 @@ struct NumericLimits return bit_cast(binary_qnan); } }; +#endif + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 + static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 + static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 + static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 + static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 + + static constexpr float data_max_normal_number = 6; + static constexpr float data_min_subnormal_number = 0.5; + + __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } + __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } + __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } + __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } + __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 + + static constexpr float data_max_normal_number = 7.5; + static constexpr float data_min_subnormal_number = 0.125; + + __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Lowest() + { + return f6_t(binary_lowest_normal & 0b111111); + } + __host__ __device__ static constexpr f6_t MinSubnorm() + { + return f6_t(binary_min_subnorm & 0b111111); + } + __host__ __device__ static constexpr f6_t MaxSubnorm() + { + return f6_t(binary_max_subnorm & 0b111111); + } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 + + static constexpr float data_max_normal_number = 28; + static constexpr float data_min_subnormal_number = 0.0625; 
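    // Worked decoding of the small-float limits above (assuming the MX-style
    // layouts also used by NumericUtils below: f4 = e2m1 with bias 1, f6 = e2m3
    // with bias 1, bf6 = e3m2 with bias 3; normal values decode as
    // (-1)^s * 2^(e - bias) * (1 + m / 2^mant)):
    //   f4  max normal  0b0111   -> 2^(3-1) * (1 + 1/2) = 4 * 1.5   = 6
    //   f6  max normal  0b011111 -> 2^(3-1) * (1 + 7/8) = 4 * 1.875 = 7.5
    //   bf6 max normal  0b011111 -> 2^(7-3) * (1 + 3/4) = 16 * 1.75 = 28
    //   bf6 min subnorm 0b000001 -> 2^(1-3) * (1/4)     = 0.0625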
+ + __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } + __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } + __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } + __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } + __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 + static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 + static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 + static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 + static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 + static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 + static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 + static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 + + __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } + __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } + __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_135() + { + return e8m0_bexp_t(binary_135); + } + __host__ __device__ static constexpr e8m0_bexp_t Binary_142() + { + return e8m0_bexp_t(binary_142); + } +}; template struct NumericUtils @@ -2028,6 +3092,7 @@ struct NumericUtils static constexpr uint32_t NegInf = 0xFF800000; static constexpr uint32_t NaN = 0x7F800001; static constexpr uint32_t Neg0 = 0x80000000; + static constexpr bool has_inf = true; using bitwise_type = uint32_t; }; @@ -2045,9 +3110,19 @@ struct NumericUtils static constexpr uint32_t NegInf = 0xFC00; static constexpr uint32_t NaN = 0x7C01; static constexpr uint32_t Neg0 = 0x8000; + static constexpr bool has_inf = true; using bitwise_type = uint16_t; }; +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 7; + static constexpr int bias = 128; // negative zero nan mode + // static constexpr int bias = 127; // ieee mode +}; + template <> struct NumericUtils { @@ -2055,6 +3130,7 @@ struct NumericUtils static constexpr int mant = 3; static constexpr int bias = 8; // negative zero nan mode // static constexpr int bias = 7; // ieee mode + static constexpr bool has_inf = false; }; template <> @@ -2064,6 +3140,7 @@ struct NumericUtils static constexpr int mant = 2; static constexpr int bias = 16; // negative zero nan mode // static constexpr int bias = 15; // ieee mode + static constexpr bool has_inf = false; }; template <> struct NumericUtils @@ -2082,11 +3159,109 @@ struct NumericUtils }; template <> -struct NumericUtils +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 1; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 10; + + static constexpr int 
unbiased_exp_min = 0; + static constexpr int unbiased_exp_max = 2; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b0000; + static constexpr uint8_t negative_zero_mask = 0b1000; + + static constexpr uint8_t one_mask = 0b0010; + static constexpr uint8_t set_sign_mask = 0b0111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b0111; + static constexpr uint8_t data_max_negative_normal_mask = 0b1111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001; + + static constexpr bool has_inf = false; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 3; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 12; + + static constexpr int unbiased_exp_min = 0; + static constexpr int unbiased_exp_max = 2; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000111; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100111; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 3; + static constexpr int mant = 2; + static constexpr int bias = 3; + static constexpr uint32_t sr_shift = 11; + + static constexpr int unbiased_exp_min = -2; + static constexpr int unbiased_exp_max = 4; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 7; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils { static constexpr int exp = 8; - static constexpr int mant = 7; - static constexpr int bias = 128; // negative zero nan mode - // static constexpr int bias = 127; // ieee mode + static constexpr int mant = 0; + static constexpr int bias = 127; + + static constexpr int unbiased_exp_min = -127; + static constexpr int unbiased_exp_max = 127; + static constexpr int biased_exp_min = 0; + static constexpr int biased_exp_max = 254; + + using bitwise_type = uint8_t; }; } // namespace ck diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp index 03c4e16dd6e8afe9e954d1e67d55e21ddf3dbbe1..2b247cc02a001c4bb1797a3ef5a4386eaec3fc98 100644 --- a/include/ck/utility/debug.hpp +++ b/include/ck/utility/debug.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP +#include "type.hpp" namespace ck { namespace debug { diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 0dcc514a2f6548d6ca4ac5f8d8c89ee09775131c..6de17a61522a226b25c9430a672a2f06532ff60f 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -29,6 +29,13 @@ struct DynamicBuffer ElementSpaceSize element_space_size_; T invalid_element_value_ = T{0}; + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) : p_data_{p_data}, element_space_size_{element_space_size} { @@ -54,7 +61,8 @@ struct DynamicBuffer template >::type, - typename scalar_type>::type>::value, + typename scalar_type>::type>::value || + !is_native_type(), bool>::type = false> __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const { @@ -81,14 +89,18 @@ struct DynamicBuffer return amd_buffer_load_invalid_element_return_zero, t_per_x, coherence>( - p_data_, i, is_valid_element, element_space_size_); + p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else { return amd_buffer_load_invalid_element_return_customized_value, t_per_x, coherence>( - p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); + p_data_, + i, + is_valid_element, + element_space_size_ / PackedSize, + invalid_element_value_); } } else @@ -190,12 +202,13 @@ struct DynamicBuffer dst_buf.p_data_, dst_offset, is_valid_element, - element_space_size_); + element_space_size_ / PackedSize); } template >::type, - typename scalar_type>::type>::value, + typename scalar_type>::type>::value || + !is_native_type(), bool>::type = false> __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) { @@ -224,7 +237,7 @@ struct DynamicBuffer constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; amd_buffer_store, t_per_x, coherence>( - x, p_data_, i, is_valid_element, element_space_size_); + x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && is_same>::type, int8_t>::value && @@ -376,7 +389,7 @@ struct DynamicBuffer constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; amd_buffer_atomic_add, t_per_x>( - x, p_data_, i, is_valid_element, element_space_size_); + x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else { @@ -415,7 +428,7 @@ struct DynamicBuffer constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; amd_buffer_atomic_max, t_per_x>( - x, p_data_, i, is_valid_element, element_space_size_); + x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else if(is_valid_element) { diff --git a/include/ck/utility/e8m0.hpp b/include/ck/utility/e8m0.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a692f533f8b9d891d6cef99ff879e54b75337103 --- /dev/null +++ b/include/ck/utility/e8m0.hpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/type.hpp" + +namespace ck { + +/** + * @brief Unsigned representation of a conventional biased Float32 exponent. 
+ * + * bias = 127; + * + * E8M0_1 = 0b01111111; => 2^(127-127) = 1 + * E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2 + * E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8 + * E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256 + * E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768 + * E8M0_MIN = 0b00000000; => 2^-127 + * E8M0_MAX = 0b11111110; => 2^127 + * E8M0_NAN = 0b11111111; => NaN + */ +struct e8m0_bexp_t +{ + using type = uint8_t; + type data; + + constexpr static type bias = 127; + constexpr static type nan_mask = 0xFF; + + __host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {} + __host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {} + __host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast(init & nan_mask)} + { + } + __host__ __device__ explicit constexpr e8m0_bexp_t(float scale) + : data{static_cast((bit_cast(scale) & (nan_mask << 23)) >> 23)} + { + } + + __host__ __device__ explicit constexpr operator float() const + { + if(data == nan_mask || data == 0) + { + uint32_t bits = data << 1; + bits |= 1; + bits <<= 22; + return bit_cast(bits); + } + else + { + uint32_t bits = data << 23; + return bit_cast(bits); + } + } + + __host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const + { + // strict IEEE compliance for NaN + return data == other.data && data != nan_mask; + } + + __host__ __device__ constexpr bool is_nan() const { return data == nan_mask; } +}; + +namespace utils { + +template +__host__ __device__ inline int get_exponent_value(T x); + +template <> +__host__ __device__ inline int get_exponent_value(e8m0_bexp_t x) +{ + return x.data; +} + +} // namespace utils + +} // namespace ck diff --git a/include/ck/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp index c0a3c99f1fdafea9f151fe9fc319c2f7aaa0ffda..6ba63fc761c300c24e6013dce499f5d1c2ba9f27 100644 --- a/include/ck/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -1,14 +1,31 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once namespace ck { +#ifndef CK_CODE_GEN_RTC template using enable_if = std::enable_if; template using enable_if_t = typename std::enable_if::type; +#else +template +struct enable_if +{ +}; + +template +struct enable_if +{ + using type = T; +}; + +template +using enable_if_t = typename enable_if::type; +#endif + } // namespace ck diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 6455402dcb331d91240ebe09d4d553f4d355f96e..809f302f743b1d9152afd98952010fddca92386a 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#ifndef CK_CODE_GEN_RTC #pragma once #include @@ -183,3 +184,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) } } // namespace ck +#endif diff --git a/include/ck/utility/functional.hpp b/include/ck/utility/functional.hpp index 91797d24092e3e32ad4a6bd40958952b124d9978..cd48ed17474480007f63180a7a25383172a3c8bd 100644 --- a/include/ck/utility/functional.hpp +++ b/include/ck/utility/functional.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
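// Standalone sketch of the E8M0 decoding documented in e8m0.hpp above (plain
// C++20; decode_e8m0 is a hypothetical helper, not part of these headers).
// For biased exponents 0x01..0xFE the value is 2^(data - 127), which is exactly
// what placing the byte in the float exponent field produces; 0x00 and 0xFF
// need the special handling shown in e8m0_bexp_t::operator float.
#include <bit>
#include <cstdint>

constexpr float decode_e8m0(std::uint8_t biased_exp)
{
    return std::bit_cast<float>(std::uint32_t(biased_exp) << 23);
}

static_assert(decode_e8m0(0x7F) == 1.0f);     // 2^(127 - 127)
static_assert(decode_e8m0(0x80) == 2.0f);     // 2^(128 - 127)
static_assert(decode_e8m0(0x82) == 8.0f);     // 2^(130 - 127)
static_assert(decode_e8m0(0x87) == 256.0f);   // 2^(135 - 127)
static_assert(decode_e8m0(0x8E) == 32768.0f); // 2^(142 - 127)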
#pragma once @@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y) { if constexpr(predicate) { - return std::forward(x); + return ck::forward(x); } else { - return std::forward(y); + return ck::forward(y); } } diff --git a/include/ck/utility/functional4.hpp b/include/ck/utility/functional4.hpp index b5f3df8d7c517dfaf01320e41721da174883c2d9..8e86a296dc2ea0e1aca99a8480d5a826583ffd30 100644 --- a/include/ck/utility/functional4.hpp +++ b/include/ck/utility/functional4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP @@ -21,7 +21,7 @@ struct unpack_impl> template __host__ __device__ constexpr auto operator()(F&& f, X&& x) const { - return std::forward(f)(std::forward(x).At(Number{})...); + return ck::forward(f)(ck::forward(x).At(Number{})...); } }; @@ -35,8 +35,8 @@ struct unpack2_impl, Sequence> template __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const { - return std::forward(f)(std::forward(x).At(Number{})..., - std::forward(y).At(Number{})...); + return ck::forward(f)(ck::forward(x).At(Number{})..., + ck::forward(y).At(Number{})...); } }; @@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x) { using X_ = remove_reference_t; return detail::unpack_impl::type>{}( - std::forward(f), std::forward(x)); + ck::forward(f), ck::forward(x)); } // TODO: properly implement unpack that takes any number of containers @@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y) using Y_ = remove_reference_t; return detail::unpack2_impl::type, typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}( - std::forward(f), std::forward(x), std::forward(y)); + ck::forward(f), ck::forward(x), ck::forward(y)); } } // namespace ck diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp index 376070eb3d8ac326603b71e52e76949c168f4219..75f35d762c52d465b39e718598082b0360905a5b 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant, integral_ return integral_constant{}; } +template +using bool_constant = integral_constant; + +using true_type = bool_constant; +using false_type = bool_constant; } // namespace ck diff --git a/include/ck/utility/is_detected.hpp b/include/ck/utility/is_detected.hpp index 7a324a6c458b3f1b8bb8037ccfac76e5eadceee0..a700fcfff1dd21de1a5784e0c74132f7812ab7fa 100644 --- a/include/ck/utility/is_detected.hpp +++ b/include/ck/utility/is_detected.hpp @@ -1,22 +1,24 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/integral_constant.hpp" + namespace ck { namespace detail { template class Op, class... Args> struct detector { - using value_t = std::false_type; + using value_t = integral_constant; using type = Default; }; template class Op, class... 
Args> -struct detector>, Op, Args...> +struct detector>, Op, Args...> { - using value_t = std::true_type; + using value_t = integral_constant; using type = Op; }; } // namespace detail @@ -32,12 +34,12 @@ template
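// Standalone sketch of the detection idiom that the is_detected.hpp hunk above
// adapts to ck::integral_constant (generic C++17 illustration; the trait name
// has_foo_t and the probed member foo() are made up for the example, not CK APIs).
#include <type_traits>
#include <utility>

template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
    using value_t = std::false_type;
    using type    = Default;
};

// Chosen when Op<Args...> is a well-formed type.
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
{
    using value_t = std::true_type;
    using type    = Op<Args...>;
};

template <template <class...> class Op, class... Args>
using is_detected = typename detector<void, void, Op, Args...>::value_t;

// Probe for a member function foo() via an expression alias.
template <class T>
using has_foo_t = decltype(std::declval<T&>().foo());

struct WithFoo { void foo(); };
struct WithoutFoo {};

static_assert(is_detected<has_foo_t, WithFoo>::value);
static_assert(!is_detected<has_foo_t, WithoutFoo>::value);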