diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..4161c2d5a4e54e731a356656bbff8864326c7fee --- /dev/null +++ b/.azuredevops/rocm-ci.yml @@ -0,0 +1,29 @@ +resources: + repositories: + - repository: pipelines_repo + type: github + endpoint: ROCm + name: ROCm/ROCm + +variables: +- group: common +- template: /.azuredevops/variables-global.yml@pipelines_repo + +trigger: + batch: true + branches: + include: + - develop + paths: + exclude: + - .github + - docs + - '.*.y*ml' + - '*.md' + - Jenkinsfile + - LICENSE + +pr: none + +jobs: + - template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 01e3bee0b615aa4125f9746f0261b743c8399bfe..459315e58b766043355046569379ab96500a3449 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,7 +1,8 @@ -* @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex +* @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk # Documentation files -docs/* @ROCm/rocm-documentation -*.md @ROCm/rocm-documentation -*.rst @ROCm/rocm-documentation +docs/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +*.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +*.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +.readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk # Header directory for Doxygen documentation -library/include/* @ROCm/rocm-documentation +library/include/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100644 new mode 100755 diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 9e6678abe5247868c15fb7b47497a9154e4a6051..b3299fa4e8830d534931a29faf22d3a9020de951 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -15,4 +15,4 @@ python: build: os: ubuntu-22.04 tools: - python: "3.8" + python: "3.10" diff --git a/CHANGELOG.md b/CHANGELOG.md index fb2ba1975fc949560f62a58ca7ae35ef514a055c..dec6334cf5b596aacde281000685d7a081b986a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,38 +1,46 @@ # Changelog for Composable Kernel -Full documentation for Composable Kernel is not yet available. +Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). -## CK for ROCm 6.1.0 +## Composable Kernel 1.1.0 for ROCm 6.1.0 ### Additions + * Added generic instances for GEMM XDL operations (#1161) * Added gamma and beta parameters for the layernorm and groupnorm bwd operations (#1133) * Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126) * Added an option to vary the number of warm-up cycles and iterations for ckProfiler (#1124) ### Optimizations + * New performance optimizations for GEMM operations on MI200 and MI300 architectures (#1135) ### Fixes + * Reduced the build time for most GPU architectures (#1084) * Fixed some conversion issues for fp8 data type (#1099) ### Changes + None ### Known issues + None -## CK for ROCm 6.0.0 +## Composable Kernel 1.1.0 for ROCm 6.0.0 ### Fixes - * Fixed a hazard associated with inline v_dot (#808) - * Fixed two bugs in grouped convolution backward data without K padding (#848 #876) + +* Fixed a hazard associated with inline v_dot (#808) +* Fixed two bugs in grouped convolution backward data without K padding (#848 #876) ### Optimizations + None ### Additions + * Added an image to a column kernel (#867) * Added a column to an image kernel (#930) * Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985) @@ -42,18 +50,22 @@ None * Support for Batched GEMM DL (#732) ### Changes - * Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) -## CK 0.2.0 for ROCm 5.7.0 +* Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) + +## Composable Kernel 0.2.0 for ROCm 5.7.0 ### Fixes + * Fixed a bug in 6-dimensional kernels (#555) * Fixed a test case failure with grouped convolution backward weight (#524) ### Optimizations + * Improved the performance of the normalization kernel ### Additions + * New CMake flags: * "DL_KERNELS"-* Must be set to "ON" in order to build the GEMM DL and batched_gemm_multi_d_dl instances * "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types @@ -71,4 +83,5 @@ None * MaxPool and AvgPool forward (#815); MaxPool backward (#750) ### Changes + None diff --git a/CMakeLists.txt b/CMakeLists.txt index a3a9801cc60e07ce57488befb37b63f9d981f8f3..7e21a79764d8abe9dbb6672d0aaf38819a0d6621 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,10 +23,10 @@ endif() set(version 1.1.0) # Check support for CUDA/HIP in Cmake -project(composable_kernel VERSION ${version} LANGUAGES CXX) +project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) include(CTest) -find_package(Python3 3.8 COMPONENTS Interpreter REQUIRED) +find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") @@ -111,13 +111,21 @@ message("checking which targets are supported") #These targets will be filtered and only supported ones will be used #Setting GPU_TARGETS on command line will override this list if(NOT PROFILER_ONLY) + if(NOT ENABLE_ASAN_PACKAGING) + #build CK for all supported targets rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") + TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") + else() + #build CK only for xnack-supported targets + rocm_check_target_ids(DEFAULT_GPU_TARGETS + TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+") + set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) + endif() else() add_definitions(-DPROFILER_ONLY) set(GPU_TARGETS "" CACHE STRING "" FORCE) if(GPU_TARGETS) - message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12") endif() if(GPU_ARCH MATCHES "gfx90") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") @@ -127,20 +135,20 @@ else() rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") elseif(GPU_ARCH MATCHES "gfx11") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") + elseif(GPU_ARCH MATCHES "gfx12") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201") else() - message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12") endif() set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) endif() message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}") -set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) - if(GPU_TARGETS) message("Building CK for the following targets: ${GPU_TARGETS}") else() - message("Building CK for the following targets: ${AMDGPU_TARGETS}") + message("Building CK for the default targets: ${DEFAULT_GPU_TARGETS}") endif() if (GPU_TARGETS) @@ -148,7 +156,7 @@ if (GPU_TARGETS) add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") endif() - if (GPU_TARGETS MATCHES "gfx11") + if (GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() @@ -202,7 +210,7 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) -option(USE_OPT_NAVI3X "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) +option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -210,10 +218,10 @@ if(USE_BITINT_EXTENSION_INT4) message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() -if(USE_OPT_NAVI3X) +if(USE_OPT_GFX11) add_compile_options(-mcumode) add_compile_options(-mno-wavefrontsize64) - message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}") + message("CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") endif() ## Threads @@ -225,7 +233,13 @@ link_libraries(Threads::Threads) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) -message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}") + +## HIP +set(CMAKE_HIP_PLATFORM amd) +set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER}) +set(CMAKE_HIP_EXTENSIONS ON) +message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") ## OpenMP if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") @@ -436,6 +450,13 @@ if(BUILD_DEV) endif() message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + add_compile_options(-fcolor-diagnostics) +endif() +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9) + add_compile_options(-fdiagnostics-color=always) +endif() + add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") diff --git a/Dockerfile b/Dockerfile index 0d3807f4d4f2acc489db6cf23afdefa79d29d6cb..196b0ee1c3419331a42ed104c9f4a929aa672e44 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,11 +23,11 @@ RUN if [ "$ROCMVERSION" != "6.2" ]; then \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ - elif [ "$ROCMVERSION" = "6.2" ] && [ "$compiler_version" = "rc2" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.1-20.04-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.1-20.04-1_all.deb && \ - sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.1 rel-48 > /etc/apt/sources.list.d/rocm-build.list' && \ - amdgpu-repo --amdgpu-build=1736298; \ + elif [ "$ROCMVERSION" = "6.2" ] && [ "$compiler_version" = "rc3" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.2-20.04-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.2-20.04-1_all.deb && \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.2 rel-45 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=2003709; \ fi RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" diff --git a/Jenkinsfile b/Jenkinsfile index ee841a18026d58c7fc42f48d8a3ec8907a784fe3..e9d55992d8894c802b167957dcf920d6646165e1 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -315,6 +315,10 @@ def buildHipClangJob(Map conf=[:]){ if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') + dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${dockerOpts}" def variant = env.STAGE_NAME @@ -366,6 +370,11 @@ def runCKProfiler(Map conf=[:]){ if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') + dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${dockerOpts}" + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " def variant = env.STAGE_NAME @@ -493,6 +502,7 @@ def Build_CK(Map conf=[:]){ def variant = env.STAGE_NAME def retimage + gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) @@ -515,38 +525,33 @@ def Build_CK(Map conf=[:]){ withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 24, unit: 'HOURS') { - //check whether running on Navi or MI300 node - def navi_node = 0 - def mi300_node = 0 + //check whether to run performance tests on this node + def do_perf_tests = 0 sh 'rocminfo | tee rocminfo.log' - if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') ){ - navi_node = 1 - echo "This is a Navi node" - } - if ( runShell('grep -n "gfx942" rocminfo.log') ){ - mi300_node = 1 - echo "This is MI300 node" + if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') || runShell('grep -n "gfx942" rocminfo.log') ){ + do_perf_tests = 1 + echo "Stash profiler and run performance tests" } cmake_build(conf) dir("build"){ //run tests and examples sh 'make -j check' - if (params.RUN_PERFORMANCE_TESTS && navi_node == 0 && mi300_node == 0 ){ + if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0 ){ //we only need the ckProfiler to run the performance tests, so we pack and stash it - //do not stash profiler on Navi or MI300 nodes - sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' - stash name: "ckProfiler.tar.gz" + //do not stash profiler on nodes where we don't need to run performance tests + sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' + stash name: "ckProfiler.tar.gz" } - if (params.RUN_FULL_QA && mi300_node == 0 ){ - // build deb packages for all MI100/200/300 targets and prepare to export - sh 'make -j package' - archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' - archiveArtifacts artifacts: 'composablekernel-tests_*.deb' - sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' - stash name: "ckprofiler_0.2.0_amd64.deb" + if (params.RUN_FULL_QA && do_perf_tests == 0 ){ + // build deb packages for all gfx9 targets and prepare to export + sh 'make -j package' + archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' + archiveArtifacts artifacts: 'composablekernel-tests_*.deb' + sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' + stash name: "ckprofiler_0.2.0_amd64.deb" } } - if (params.hipTensor_test && navi_node == 0 ){ + if (params.hipTensor_test && do_perf_tests == 0 ){ //build and test hipTensor sh """#!/bin/bash rm -rf "${params.hipTensor_branch}".zip @@ -657,10 +662,11 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.1;COMPILER_VERSION= - 0 21 * * * % ROCMVERSION=6.1;COMPILER_VERSION=;COMPILER_COMMIT= +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.1; RUN_CK_TILE_TESTS=true + 0 21 * * * % ROCMVERSION=6.1;hipTensor_test=true 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false - 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : "" + 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false + 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : "" pipeline { agent none @@ -705,8 +711,8 @@ pipeline { description: "Select whether to build DL kernels (default: OFF)") booleanParam( name: "hipTensor_test", - defaultValue: true, - description: "Use the CK build to verify hipTensor build and tests (default: ON)") + defaultValue: false, + description: "Use the CK build to verify hipTensor build and tests (default: OFF)") string( name: 'hipTensor_branch', defaultValue: 'mainline', @@ -727,6 +733,14 @@ pipeline { name: "RUN_CODEGEN_TESTS", defaultValue: true, description: "Run the codegen tests (default: ON)") + booleanParam( + name: "RUN_CK_TILE_TESTS", + defaultValue: false, + description: "Run the ck_tile tests (default: OFF)") + booleanParam( + name: "BUILD_INSTANCES_ONLY", + defaultValue: false, + description: "Test building instances for various architectures simultaneously (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -809,22 +823,67 @@ pipeline { { parallel { - stage("Run Codegen Tests on MI100/MI200") + stage("Run Codegen Tests on gfx90a") { when { beforeAgent true expression { params.RUN_CODEGEN_TESTS.toBoolean() } } - options { retry(2) } - agent{ label rocmnode("gfx908 || gfx90a")} + agent{ label rocmnode("gfx90a")} environment{ setup_args = "NO_CK_BUILD" execute_args = """ cd ../codegen && rm -rf build && mkdir build && cd build && \ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ -D CMAKE_BUILD_TYPE=Release \ - -D GPU_TARGETS="gfx908;gfx90a" \ - -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j check""" + -D GPU_TARGETS="gfx90a" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j check""" + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } + stage("Run CK_TILE Tests") + { + parallel + { + stage("Run CK_TILE Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ + cd ../ && + example/ck_tile/01_fmha/script/smoke_test_fwd.sh && \ + example/ck_tile/01_fmha/script/smoke_test_bwd.sh""" + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Run CK_TILE Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ + make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ + cd ../ && + example/ck_tile/01_fmha/script/smoke_test_fwd.sh && \ + example/ck_tile/01_fmha/script/smoke_test_bwd.sh""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -837,30 +896,30 @@ pipeline { { parallel { - stage("Build CK and run Tests on MI100/MI200/MI300") + stage("Build CK for all gfx9 targets") { when { beforeAgent true expression { params.RUN_FULL_QA.toBoolean() } } - agent{ label rocmnode("gfx908 || gfx90a") } + agent{ label rocmnode("gfx90a") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ -DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " \ - -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """ + -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } - stage("Build CK and run Tests on MI300") + stage("Build CK and run Tests on gfx942") { when { beforeAgent true @@ -868,45 +927,64 @@ pipeline { } agent{ label rocmnode("gfx942") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } - stage("Build CK and run Tests on MI100/MI200") + stage("Build CK and run Tests on gfx90a") { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } } - agent{ label rocmnode("gfx908 || gfx90a") } + agent{ label rocmnode("gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1100;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx908;gfx90a" \ + -DGPU_TARGETS="gfx1100;gfx90a" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } - stage("Build CK and run Tests on Navi21") + stage("Build CK instances for different targets") { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() } } - agent{ label rocmnode("navi21") } + agent{ label rocmnode("gfx90a") } + environment{ + execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D INSTANCES_ONLY=ON \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """ + } + steps{ + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Build CK and run Tests on gfx1030") + { + when { + beforeAgent true + expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } + } + agent{ label rocmnode("gfx1030") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ @@ -920,13 +998,13 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on Navi32") + stage("Build CK and run Tests on gfx1101") { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } } - agent{ label rocmnode("navi32") } + agent{ label rocmnode("gfx1101") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ @@ -947,29 +1025,13 @@ pipeline { { parallel { - stage("Run ckProfiler: gfx90*") - { - when { - beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() } - } - options { retry(2) } - agent{ label rocmnode("gfx908 || gfx90a")} - environment{ - setup_args = """ -DGPU_TARGETS="gfx908;gfx90a" -DBUILD_DEV=On """ - } - steps{ - runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') - cleanWs() - } - } stage("Run ckProfiler: gfx90a") { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() } + expression { params.RUN_PERFORMANCE_TESTS.toBoolean() } } - options { retry(2) } + options { retry(1) } agent{ label rocmnode("gfx90a")} environment{ setup_args = """ -DGPU_TARGETS="gfx90a" -DBUILD_DEV=On """ diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index 710eca9f491a51f5357164c1d39386b680508b4b..e8c046ff44bef2d2f6b71ae029a11b51e3b1a9b7 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -4,4 +4,22 @@ if(GPU_TARGETS MATCHES "gfx9") add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) -endif() \ No newline at end of file + + if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) + endif() + + if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) + endif() + + if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_fp8_bf8 grouped_conv3d_fwd_fp8_bf8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) + + add_executable(client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) + endif() +endif() diff --git a/client_example/07_grouped_convnd_fwd/common.hpp b/client_example/07_grouped_convnd_fwd/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..729af0b88b8d48a3361098aac1a155ad4e013308 --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/common.hpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths) +{ + // 2 * G * N * K * C * * + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return static_cast(2) * G * N * K * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd(std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t flop = GetFlops(out_lengths, wei_lengths); + std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp index 4983ac33c3d759f3732fda2db526f35871d51322..d3a3111e945da61d30164aec4ee23b1c1c5e4689 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp @@ -1,17 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include -#include -#include +#include "common.hpp" #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -31,199 +24,16 @@ static constexpr ck::index_t X = 3; static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wo = 28; -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - int main() { - std::array in_lengths{G, N, Wi, C}; - std::array in_strides{0, 0, 0, 1}; - - std::array wei_lengths{G, K, X, C}; - std::array wei_strides{0, 0, 0, 1}; - - std::array out_lengths{G, N, Wo, K}; - std::array out_strides{0, 0, 0, 1}; - - std::partial_sum(rbegin(in_lengths), - std::prev(rend(in_lengths)), - std::next(rbegin(in_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(wei_lengths), - std::prev(rend(wei_lengths)), - std::next(rbegin(wei_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(out_lengths), - std::prev(rend(out_lengths)), - std::next(rbegin(out_strides)), - std::multiplies<>{}); - - // transpose GNWC/GKXC/GNWK to GNCW/GKCX/GNCW - std::rotate(rbegin(in_lengths), - std::next(rbegin(in_lengths)), - std::next(rbegin(in_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(in_strides), - std::next(rbegin(in_strides)), - std::next(rbegin(in_strides), NumDimSpatial + 1)); - std::rotate(rbegin(wei_lengths), - std::next(rbegin(wei_lengths)), - std::next(rbegin(wei_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(wei_strides), - std::next(rbegin(wei_strides)), - std::next(rbegin(wei_strides), NumDimSpatial + 1)); - std::rotate(rbegin(out_lengths), - std::next(rbegin(out_lengths)), - std::next(rbegin(out_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(out_strides), - std::next(rbegin(out_strides)), - std::next(rbegin(out_strides), NumDimSpatial + 1)); - - std::array filter_strides{1}; - std::array filter_dilations{1}; - std::array input_left_pads{1}; - std::array input_right_pads{1}; - - SimpleDeviceMem in(sizeof(InDataType) * G * N * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * G * N * Wo * K); - - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_op_name; - int best_op_id = -1; - float best_avg_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - float best_tflops = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::string op_name = op_ptr->GetTypeString(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - - std::size_t flop = std::size_t(2) * G * N * K * C * Wo * X; - std::size_t num_bytes = sizeof(InDataType) * G * N * Wi * C + - sizeof(WeiDataType) * G * K * X * C + - sizeof(OutDataType) * G * N * Wo * K; - - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) - { - best_op_id = i; - best_op_name = op_name; - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - best_tflops = tflops; - } - } - else - { - std::cerr << op_name << " does not support this problem" << std::endl; - } - } - - if(best_op_id < 0) - { - std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; - } - - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - - // run the best intance - { - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } + return run_grouped_conv_fwd({N, Wi, G, C}, {G, K, X, C}, {N, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; } diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp index 93833506290ce7580b43ea7860bc7606778abb86..fb8a410ab39b71b56a4dfb4df8b62aab16e39d35 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp @@ -1,17 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include -#include -#include +#include "common.hpp" #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -34,167 +27,16 @@ static constexpr ck::index_t Wi = 28; // input W static constexpr ck::index_t Ho = 28; // output H static constexpr ck::index_t Wo = 28; // output W -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - int main() { - // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space - // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW - // Hence, we need to adjust the order of stride - std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; - std::array wei_lengths{G, K, C, Y, X}; - std::array wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; - std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; - - std::array filter_strides{1, 1}; - std::array filter_dilations{1, 1}; - std::array input_left_pads{1, 1}; - std::array input_right_pads{1, 1}; - - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_op_name; - int best_op_id = -1; - float best_avg_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - float best_tflops = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::string op_name = op_ptr->GetTypeString(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - - std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + - sizeof(WeiDataType) * G * K * Y * X * C + - sizeof(OutDataType) * N * Ho * Wo * G * K; - - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) - { - best_op_id = i; - best_op_name = op_name; - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - best_tflops = tflops; - } - } - else - { - std::cerr << op_name << " does not support this problem" << std::endl; - } - } - - if(best_op_id < 0) - { - std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; - } - - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - - // run the best intance - { - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } + return run_grouped_conv_fwd({N, Hi, Wi, G, C}, {G, K, Y, X, C}, {N, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; } diff --git a/client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp diff --git a/client_example/11_grouped_conv_bwd_weight/common.hpp b/client_example/11_grouped_conv_bwd_weight/common.hpp index 1a36490ef4d279aae5fa25ea8b3a55819842df82..541a0a19a07fb778eeff2f3fdf25b25cfbc1ab4f 100644 --- a/client_example/11_grouped_conv_bwd_weight/common.hpp +++ b/client_example/11_grouped_conv_bwd_weight/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -160,6 +160,10 @@ bool run_grouped_conv_bwd_weight( auto invoker_ptr = op_ptr->MakeInvokerPointer(); std::string op_name = op_ptr->GetTypeString(); + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index 23311b4024e476bd8a93e340c7f42ea3a6d55b7e..5279e3dfcf8196533f59c419d2b3c81d0b53c249 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -7,22 +7,6 @@ endif() if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations) - - add_executable(client_conv3d_fwd_fp8 conv3d_fwd_fp8.cpp) - target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) -endif() - -if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) - add_executable(client_conv3d_fwd_bf8 conv3d_fwd_bf8.cpp) - target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) -endif() - -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) - add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp) - target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) - - add_executable(client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp) - target_link_libraries(client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) endif() if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index d4d5c545c9aab231b183109e7f308b584bd38637..37bace920d20331680025cb18947229148fa31f7 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -35,6 +35,30 @@ target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composa add_executable(client_grouped_convnd_fwd_bilinear_residual_fp16 grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp) target_link_libraries(client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) +# Fwd convinvscale +add_executable(client_conv3d_fwd_convinvscale_fp8 + grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convinvscale_fp8 PRIVATE composable_kernel::device_conv_operations) +# Fwd convscale + ReLU +add_executable(client_conv3d_fwd_convscale_relu_fp8 + grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_relu_fp8 PRIVATE composable_kernel::device_conv_operations) +# Fwd convscale +add_executable(client_conv3d_fwd_convscale_fp8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_fp8 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_conv3d_fwd_convscale_bf8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_bf8 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_conv3d_fwd_convscale_fp8_bf8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_conv3d_fwd_convscale_bf8_fp8 + grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp) +target_link_libraries(client_conv3d_fwd_convscale_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) # Bwd data bilinear add_executable(client_grouped_convnd_bwd_data_bilinear_residual_fp16 grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp) diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7059e24d8e1d69524c5be27c5f427a9f93c82049 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/common.hpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd_convinvscale( + std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(out_lengths, wei_lengths, ds_size); + std::size_t num_bytes = + in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + ConvInvscale, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvInvscale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvInvscale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..775ea99ecd520304fcf90d242e96d3cec330430a --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convinvscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..51eec5b1ab14dd3e66e026fdedcb8a9b41a58b3d --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/common.hpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd_convscale( + std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(out_lengths, wei_lengths, ds_size); + std::size_t num_bytes = + in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + ConvScale, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScale{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f901d08ab6eaf784dfd86e337cffe76725798c04 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = InDataType; +using BComputeDataType = AComputeDataType; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..192c4fdcb904a4b65ba4e22fdbc7caa1dd1fb35f --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::bf8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15d063c2f139d35c6824557cb615099c93513fd6 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b38225f2b916b3cd347d95190a2581b14d494684 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::bf8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ee188429b4cb4750dc7776bd199d56e005788a1d --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/common.hpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd_convscale_relu( + std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t ds_size = 3 + 1; // 3 element-wise scale multipliers + 1 elementwise Relu + std::size_t flop = GetFlops(out_lengths, wei_lengths, ds_size); + std::size_t num_bytes = + in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + ConvScaleRelu, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScaleRelu{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{}, + std::array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ConvScaleRelu{scale_in, scale_wei, scale_out}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4003dc7c86c65369b9bc09225ca227d81a7ffa80 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using CShuffleDataType = float; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd_convscale_relu( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/25_wrapper/wrapper_basic_gemm.cpp b/client_example/25_wrapper/wrapper_basic_gemm.cpp index 59c5c243ce9b6b8a513973d2c9198bebe61d4a94..23245dd1889e3890e3c474075d06f661c915dee6 100644 --- a/client_example/25_wrapper/wrapper_basic_gemm.cpp +++ b/client_example/25_wrapper/wrapper_basic_gemm.cpp @@ -7,19 +7,23 @@ #include #include +#include "ck/utility/common_header.hpp" +// __gfx9__ defined in the above header via ck.hpp +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/host_tensor.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/check_err.hpp" -#include "ck/utility/common_header.hpp" #include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" #include "ck/wrapper/layout.hpp" #include "ck/wrapper/tensor.hpp" #include "ck/wrapper/operations/copy.hpp" #include "ck/wrapper/operations/gemm.hpp" #include "ck/wrapper/utils/kernel_utils.hpp" +#include "ck/host_utility/device_prop.hpp" struct SimpleDeviceMem { @@ -204,6 +208,14 @@ void PerformGemm(const ck::index_t M, int main(int argc, char* argv[]) { + bool is_supported = ck::is_xdl_supported(); + if(!is_supported) + { + std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + using DataType = ck::half_t; const auto thread_layout = ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}), @@ -213,3 +225,4 @@ int main(int argc, char* argv[]) 3840, 4096, 4096, tile_shape, thread_layout); return 0; } +#endif diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp index 2a4034d62fad4217a438e439de8659002433825d..ceccc5eb8fe965f025f0569d4fc6bb6d31af3922 100644 --- a/client_example/25_wrapper/wrapper_img2col.cpp +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -181,4 +181,3 @@ int main(int argc, char* argv[]) {1, 1, 1} /*filter_dilations*/); return 0; } -// MI100 Perf: 0.255178 ms, 1698.9 GB/s, diff --git a/client_example/25_wrapper/wrapper_optimized_gemm.cpp b/client_example/25_wrapper/wrapper_optimized_gemm.cpp index b6294c2393a89d9037fc03df5ef0aae8e3ff04da..31e20342df31c7f228ef0f193da7f32f0f033550 100644 --- a/client_example/25_wrapper/wrapper_optimized_gemm.cpp +++ b/client_example/25_wrapper/wrapper_optimized_gemm.cpp @@ -7,18 +7,21 @@ #include #include -#include "ck/library/utility/host_tensor.hpp" +#include "ck/utility/common_header.hpp" +// __gfx9__ defined in the above header via ck.hpp +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) #include "ck/host_utility/kernel_launch.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/check_err.hpp" -#include "ck/utility/common_header.hpp" #include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" #include "ck/wrapper/layout.hpp" #include "ck/wrapper/tensor.hpp" #include "ck/wrapper/operations/copy.hpp" #include "ck/wrapper/operations/gemm.hpp" #include "ck/wrapper/utils/kernel_utils.hpp" +#include "ck/host_utility/device_prop.hpp" struct SimpleDeviceMem { @@ -296,6 +299,14 @@ void PerformGemm(const ck::index_t M, int main(int argc, char* argv[]) { + bool is_supported = ck::is_xdl_supported(); + if(!is_supported) + { + std::cout << "WARNING: xdl example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + using DataType = ck::half_t; const auto thread_layout = ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), @@ -305,3 +316,4 @@ int main(int argc, char* argv[]) 3840, 4096, 4096, tile_shape, thread_layout); return 0; } +#endif diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp index 36637df46d57afff376619008d37bc2c8bbffbaf..4b284c74d4a75cf3634b77e703d369df44ba8098 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp @@ -13,7 +13,7 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp" #include "ck/host_utility/hip_check_error.hpp" diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp index f71b6a13feecdd497ec2a91878831b65e275ac9f..6cc83e06f68555f84d642219e3f49aafd627fa66 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp @@ -13,7 +13,7 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp" #include "ck/host_utility/hip_check_error.hpp" diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 8eb662d281b302ee522892ea812a3eb508c98d6b..d2222a840ef536ebd8c668e5242fc092d41b5bc4 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -6,46 +6,36 @@ if (DTYPES) add_definitions(-DDTYPES) if (DTYPES MATCHES "int8") add_definitions(-DCK_ENABLE_INT8) - if(NOT DEFINED ${CK_ENABLE_INT8}) - set(CK_ENABLE_INT8 "ON") - endif() + set(CK_ENABLE_INT8 "ON") endif() if (DTYPES MATCHES "fp8") add_definitions(-DCK_ENABLE_FP8) - if(NOT DEFINED ${CK_ENABLE_FP8}) - set(CK_ENABLE_FP8 "ON") - endif() + set(CK_ENABLE_FP8 "ON") + endif() + if (DTYPES MATCHES "bf8") + add_definitions(-DCK_ENABLE_BF8) + set(CK_ENABLE_BF8 "ON") endif() if (DTYPES MATCHES "fp16") add_definitions(-DCK_ENABLE_FP16) - if(NOT DEFINED ${CK_ENABLE_FP16}) - set(CK_ENABLE_FP16 "ON") - endif() + set(CK_ENABLE_FP16 "ON") endif() if (DTYPES MATCHES "fp32") add_definitions(-DCK_ENABLE_FP32) - if(NOT DEFINED ${CK_ENABLE_FP32}) - set(CK_ENABLE_FP32 "ON") - endif() + set(CK_ENABLE_FP32 "ON") endif() if (DTYPES MATCHES "fp64") add_definitions(-DCK_ENABLE_FP64) - if(NOT DEFINED ${CK_ENABLE_FP64}) - set(CK_ENABLE_FP64 "ON") - endif() + set(CK_ENABLE_FP64 "ON") endif() if (DTYPES MATCHES "bf16") add_definitions(-DCK_ENABLE_BF16) - if(NOT DEFINED ${CK_ENABLE_BF16}) - set(CK_ENABLE_BF16 "ON") - endif() + set(CK_ENABLE_BF16 "ON") endif() message("DTYPES macro set to ${DTYPES}") else() - add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) - if(NOT DEFINED ${CK_ENABLE_ALL_DTYPES}) - set(CK_ENABLE_ALL_DTYPES "ON") - endif() + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) + set(CK_ENABLE_ALL_DTYPES "ON") endif() if (GPU_TARGETS) @@ -73,7 +63,8 @@ message(STATUS "Build with HIP ${hip_VERSION}") # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) FOREACH(subdir ${dir_list}) - IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build")) + IF(IS_DIRECTORY "${subdir}" AND (NOT "${subdir}" MATCHES "build") + AND (NOT "${subdir}" MATCHES ".vscode")) add_subdirectory(${subdir}) ENDIF() ENDFOREACH() diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 8654170b3ddf326d292b00bcb9408c81dfe7e7f2..93fd306e98af3bf86fcc6f0f213d029f6f3c4a26 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -2,7 +2,7 @@ # # MIT License # -# Copyright (c) 2017 Advanced Micro Devices, Inc. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -66,7 +66,7 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror + -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt @@ -96,6 +96,7 @@ else() -Wno-covered-switch-default -Wno-unsafe-buffer-usage -Wno-unused-lambda-capture + -Wno-nvcc-compat ) else() if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 72549c9a4ef40490c0284b47e22420bbbcdfacba..d8b22fc943be61390f169c8109595495ace20d8d 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.16) -project(composable_kernel_host) +project(composable_kernel_host LANGUAGES CXX HIP) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -12,24 +12,38 @@ find_package(ROCM) include(ROCMInstallTargets) include(ROCMTest) +add_compile_options(-std=c++17) +find_package(hip) +## HIP +set(CMAKE_HIP_PLATFORM amd) +set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER}) +set(CMAKE_HIP_EXTENSIONS ON) +message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") + +# add include directories +include_directories(BEFORE + ${PROJECT_BINARY_DIR}/include + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/library/include + ${HIP_INCLUDE_DIRS} + ) + list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) include(Embed) file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS - ${CK_ROOT}/include/ck/*.hpp) + ${CK_ROOT}/include/ck/*.hpp) message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "RELATIVE: ${CK_ROOT}/include") add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) -add_definitions(-std=c++17) - file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) target_link_libraries(ck_host PRIVATE ck_headers) -set_target_properties(ck_host PROPERTIES - LINKER_LANGUAGE CXX - POSITION_INDEPENDENT_CODE ON) +set_target_properties(ck_host PROPERTIES + LINKER_LANGUAGE CXX + POSITION_INDEPENDENT_CODE ON) target_include_directories(ck_host PUBLIC $ diff --git a/codegen/driver/main.cpp b/codegen/driver/main.cpp index dfd513106b0b38600b234696035575c37bdb8aed..c7d295de943e1feb5b139d933118db745e7edee3 100644 --- a/codegen/driver/main.cpp +++ b/codegen/driver/main.cpp @@ -5,24 +5,27 @@ #include #include #include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/stringutils.hpp" using ck::host::Transform; struct Emitters { + // retrieve the hard-coded instances provided, template them, and then store them in a map std::unordered_map()>> m; template - void Register(const std::string& name) + void Register(const std::string& name, const std::string& prologue, const std::string& epilogue) { - m[name] = [] { - auto configs = T::CreateOperations(); + m[name] = [&] { + auto configs = T::CreateOperations(prologue, epilogue); return Transform(configs, [](const auto& ops) { return ToTuple(ops); }); }; } + // takes in an operation instance and uses it to substitute the correct values into the template template static std::string ToTuple(const T& ops) { @@ -31,6 +34,7 @@ struct Emitters return "std::tuple<\n" + ck::host::JoinStrings(templates, ",\n") + ">"; } + // Join together all the strings in the map std::string Emit(const std::string& name) { return ck::host::JoinStrings(m.at(name)(), "\n"); } std::vector List() const @@ -43,9 +47,38 @@ int main(int argc, const char* argv[]) { std::string prog = argv[0]; std::vector args(argv + 1, argv + argc); + + // Specify problem type and problem size + ck::host::device_gemm_multiple_d::Problem prob; + prob.M = 1024; + prob.N = 1024; + prob.K = 1024; + + // user provided fusion + std::string prologue = ""; + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};)"; + + // Load in operations into the Register Emitters e; e.Register( - "DeviceGemmMultipleD_Xdl_CShuffle"); + "DeviceGemmMultipleD_Xdl_CShuffle", prologue, epilogue); if(args.empty() or std::any_of(args.begin(), args.end(), [](auto arg) { return arg == "-h" or arg == "--help"; @@ -64,6 +97,7 @@ int main(int argc, const char* argv[]) return 0; } + // print out all the instances for the operation that was chosen at the command line for(auto name : args) std::cout << e.Emit(name) << std::endl; diff --git a/codegen/include/ck/host/device_gemm_multiple_d.hpp b/codegen/include/ck/host/device_gemm_multiple_d.hpp index 88e040db53fd1c9d064ce65fc86cee3b0cdb49b9..02c19c88e7335c467da119dd46fb3408b80b319e 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp index f9d39633ac7d3ebe484246dd37012cb880d84718..359da7d8cf5ee48aab4cd6a4e987a49830fac88c 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp @@ -14,10 +14,15 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// defines all values need for an instance of fwd conv struct Operation_Xdl_CShuffle { - static std::vector> CreateOperations(); - static std::vector CreateOperations(const Problem& prob); + // returns a vector of instances, only given fusion operators: will use default problem spec + static std::vector> + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, given a problem spec and fusion operators + static std::vector + CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue); TensorDesc A{}; TensorDesc B{}; DataType acc = DataType::Float; @@ -27,13 +32,21 @@ struct Operation_Xdl_CShuffle std::string a_elem_op = PassThrough; std::string b_elem_op = PassThrough; std::string cde_elem_op = Bilinear; + std::string prologue = ""; + std::string epilogue = ""; std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; + // tuning parameters operation::TileDesc tile_desc{}; operation::BlockTransferDesc a_block_transfer{}; operation::BlockTransferDesc b_block_transfer{}; operation::CShuffleDesc cshuffle{}; operation::CBlockTransferDesc c_block_transfer{}; + // functions to update fusion operators if provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + /**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_); + // returns a templated instance Solution ToSolution() const; }; diff --git a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp index f6dbc2b6e877f9a74911b3aa9f53e0ce2c72db79..f4036328ecc6848bc559a371d0d05fcbf482374e 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,11 +12,14 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// defines the problem specification for a GEMM operation struct Problem { - std::size_t M = 0; - std::size_t N = 0; - std::size_t K = 0; + // dimensions for GEMM operation + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + // layouts for tensors bool TransA = false; bool TransB = false; bool TransE = false; @@ -29,9 +32,13 @@ struct Problem std::string BElementOp = PassThrough; std::string CDEElementOp = PassThrough; + // returns the correct device op file for the operation std::string GetIncludeHeader() const; - std::vector GetSolutions(const std::string& arch) const; + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; }; } // namespace device_gemm_multiple_d diff --git a/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5ad1dce1762b847f7385400608d5a776317676b3 --- /dev/null +++ b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" + +namespace ck { +namespace host { +namespace conv { + +// defines the values needed for an instance of forward convolution and functions to return +// (templated) instances +struct Operation_Conv_Fwd_Xdl_Cshuffle +{ + // returns a vector of instances given the fusion operations, uses default values for problem + // spec + static std::vector + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, provided with a problem spec and fusion operations + static std::vector CreateOperations( + const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue); + std::size_t NumDim; + TensorDesc A{}; + TensorDesc B{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::vector Ds = {}; + TensorDesc E{}; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string cde_elem_op = PassThrough; + std::string prologue = ""; + std::string epilogue = ""; + std::string conv_specialization = + "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default"; + std::string gemm_specialization = + "ck::tensor_operation::device::GemmSpecialization::MNKPadding"; + // tuning parameters + operation::TileDesc tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + // functions to update fusion operations if they are provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + // returns a templated instance + Solution ToSolution() const; +}; + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..433f9a8fc92f100486780f0a6d6e246d2e83a43f --- /dev/null +++ b/codegen/include/ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace conv { + +// defines the problem specification for a forward convolution operation +struct Problem_Conv_Fwd +{ + std::size_t NumDim = 0; + // size of a forward convolution operation + std::size_t G = 0; + std::size_t N = 0; + std::size_t C = 0; + std::size_t Hi = 0; + std::size_t Wi = 0; + std::size_t Ho = 0; + std::size_t Wo = 0; + std::size_t K = 0; + std::size_t Y = 0; + std::size_t X = 0; + Layout ALayout = Layout::NHWGC; + Layout BLayout = Layout::GKYXC; + Layout ELayout = Layout::NHWGK; + std::vector DsLayout = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; + std::vector DsDataType = {}; + std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string CDEElementOp = "ck::tensor_operation::element_wise::PassThrough"; + + // returns the correct device op file for the operation + std::string GetIncludeHeader() const; + + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; +}; + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/headers.hpp b/codegen/include/ck/host/headers.hpp index 3da05baaaf06948464ae9866833bfc7a214c8ba8..54f8d9f7317c3511a12c723bd6eadfaba711ffa9 100644 --- a/codegen/include/ck/host/headers.hpp +++ b/codegen/include/ck/host/headers.hpp @@ -4,7 +4,6 @@ #pragma once #include -#include #include #include #include diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp index f587122b058dfe978bd89bc6c13e882d18a258df..84ef92f0a039706e1da4719ca9576667fac44494 100644 --- a/codegen/include/ck/host/operation/gemm.hpp +++ b/codegen/include/ck/host/operation/gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp index 01374b86c8bbbc0f6de40820d52ac6e7f3716223..89c1884d2e4283604eecf33cf6b5dffe42e1c367 100644 --- a/codegen/include/ck/host/stringutils.hpp +++ b/codegen/include/ck/host/stringutils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp index 23488a66d0e48880e3462a2c8c3e9a69b4ab1484..812c073678bb663cbdd884bba9604963ed4a02bb 100644 --- a/codegen/include/ck/host/types.hpp +++ b/codegen/include/ck/host/types.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,6 +12,7 @@ namespace ck { namespace host { +// holds the templated instance, substitues values into template from instancess struct Solution { @@ -33,6 +34,7 @@ struct Solution std::unordered_map template_values; }; +// supported data types enum class DataType { Half, @@ -40,22 +42,28 @@ enum class DataType Int8, Int32 }; - std::string ToString(DataType dt); +// supported layouts: gemm and fwd conv enum class Layout { Row, - Column + Column, + GKYXC, + GKCYX, + GNHWK, + GNHWC, + NHWGC, + NHWGK }; - std::string ToString(Layout dl); +Layout ToLayout(bool Trans); // returns the layout for gemm +// supported GEMM types enum class GemmType { Default }; - std::string ToString(GemmType gt); struct TensorDesc diff --git a/codegen/include/ck/host/utils.hpp b/codegen/include/ck/host/utils.hpp index e8785a456f3a1725dee535a81ec7f2006819e7e1..21926814f1d27a0727d952a1a5674e7ada520012 100644 --- a/codegen/include/ck/host/utils.hpp +++ b/codegen/include/ck/host/utils.hpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include +#include +#include namespace ck { namespace host { @@ -12,6 +14,5 @@ namespace host { std::size_t integer_divide_ceil(std::size_t x, std::size_t y); const std::unordered_set& get_xdlop_archs(); - } // namespace host } // namespace ck diff --git a/codegen/src/device_gemm_multiple_d.cpp b/codegen/src/device_gemm_multiple_d.cpp index ec25afc0f986e745fdb2539f104bcf98be3f2192..44bc051a8b021e1e3a7461732034dc77e063ae76 100644 --- a/codegen/src/device_gemm_multiple_d.cpp +++ b/codegen/src/device_gemm_multiple_d.cpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_gemm_multiple_d/operation.hpp" @@ -11,23 +11,28 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// return the relevant device op file based on the operation std::string Problem::GetIncludeHeader() const { return "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"; } -std::vector Problem::GetSolutions(const std::string& arch) const +// returns templated instances when provided with a problem specification +std::vector Problem::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const { if(get_xdlop_archs().count(arch) == 0) return {}; - auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations(*this); + auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations( + *this, prologue, epilogue); // obtains vector of instances std::vector result; std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { - return op.ToSolution(); + return op.ToSolution(); // template instance with correct values }); return result; } } // namespace device_gemm_multiple_d } // namespace host -} // namespace ck \ No newline at end of file +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp index 9e397497eeb0e2ef599b7df18e83ba4d6e5aa1f5..a2e8eccbf107b30e7b7120016fd864964859e43d 100644 --- a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -10,6 +10,7 @@ namespace ck { namespace host { namespace device_gemm_multiple_d { +// calculate appropriate Gemm Specification based on input tensor dimensions static std::string GetGemmSpec(const std::size_t m, const std::size_t n, const std::size_t k, @@ -30,9 +31,40 @@ static std::string GetGemmSpec(const std::size_t m, return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; } +// function to update prologue/epilogue with user provided operation +void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue) +{ + if(!prologue.empty()) + { + this->prologue = prologue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->prologue = ""; + } +} + +void Operation_Xdl_CShuffle::update_epilogue(const std::string& epilogue) +{ + if(!epilogue.empty()) + { + this->epilogue = epilogue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->epilogue = ""; + } +} + +// accounts for all possible combinations of Row/Col major static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } -std::vector Operation_Xdl_CShuffle::CreateOperations(const Problem& prob) +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Xdl_CShuffle::CreateOperations( + const Problem& prob, const std::string& prologue, const std::string& epilogue) { std::vector result; @@ -155,6 +187,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(con // clang-format on }; + // choose correct arrangement of tuning parameters based on the layout of each tensor const auto a_block_descriptions = prob.TransA ? a_block_descriptions_colmajor : a_block_descriptions_rowmajor; const auto b_block_descriptions = @@ -165,6 +198,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(con assert(tile_descriptions.size() == cshuffle_descriptions.size()); assert(tile_descriptions.size() == c_block_descriptions.size()); + // Put all values together into a single operation > store into the result vector for(std::size_t i = 0; i < tile_descriptions.size(); i++) { Operation_Xdl_CShuffle x; @@ -188,12 +222,17 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(con x.tile_desc.m_per_block, x.tile_desc.n_per_block, x.tile_desc.k_per_block); + x.update_prologue(prologue); + x.update_epilogue(epilogue); result.push_back(x); } return result; } -std::vector> Operation_Xdl_CShuffle::CreateOperations() +// set up instances when not provided with a problem specification, use default operation values and +// all possible layout combinations +std::vector> +Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue) { std::vector problems; for(bool TransA : {true, false}) @@ -204,7 +243,8 @@ std::vector> Operation_Xdl_CShuffle::CreateO prob.TransB = TransB; problems.push_back(prob); } - return Transform(problems, [](const Problem& p) { return CreateOperations(p); }); + return Transform(problems, + [&](const Problem& p) { return CreateOperations(p, prologue, epilogue); }); } static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = @@ -224,9 +264,20 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " "${CDEBlockTransferScalarPerVector_NPerBlock}>"; +// use hardcoded instances from vector of operations to substitute values into instance template Solution Operation_Xdl_CShuffle::ToSolution() const { std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.m_per_block) + "_" + + std::to_string(this->tile_desc.n_per_block) + "_" + + std::to_string(this->tile_desc.k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.n_Xdl_per_wave)}, {"LayoutA", ToString(this->A.layout)}, {"LayoutB", ToString(this->B.layout)}, {"LayoutDs", diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c689e5ec950e5a8eae22a8328150a895f4c03048 --- /dev/null +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd.cpp @@ -0,0 +1,42 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/utils.hpp" +#include +#include + +namespace ck { +namespace host { +namespace conv { + +// return the relevant device op file based on the operation +// NOTE: this is a modified version of the original CK file that calls the kernel from a device +// function and makes the Argument class accessible on the device +std::string Problem_Conv_Fwd::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/" + "codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"; +} + +// return vector of forward convolution instances when provided with a problem instance +std::vector Problem_Conv_Fwd::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::conv::Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations( + *this, prologue, epilogue); + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); + }); + return result; +} + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94161a76d999d838fe32c3ac38af222e4d857037 --- /dev/null +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp @@ -0,0 +1,364 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace conv { + +// calculate appropriate Gemm Specification based on input tensor dimensions +// NOTE: in CK, MNKPadding is always used for forward convolution +static std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +// function to update prologue/epilogue with user provided operation +void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologue) +{ + if(!prologue.empty()) + { + this->prologue = prologue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->prologue = ""; + } +} + +void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epilogue) +{ + if(!epilogue.empty()) + { + this->epilogue = epilogue; + this->cde_elem_op = "CDEElementOp"; + } + else + { + this->epilogue = ""; + } +} + +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations( + const Problem_Conv_Fwd& prob, const std::string& prologue, const std::string& epilogue) +{ + std::vector result; + + std::vector tile_descriptions = { + // clang-format off +// Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| NumGemmK| +// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch| +// | | | | | | | | Wave| Wave| Stage| +// | | | | | | | | | | | + { 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1}, + { 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1}, + { 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1}, + { 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1}, + { 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1}, + { 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1} + // clang-format on + }; + + std::vector a_block_descriptions = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1} + // clang-format on + }; + + std::vector b_block_descriptions = { + // clang-format off +// BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1} + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1} + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// | + { S<1, 16, 1, 4>, 1}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 4>, 1}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 8>, 8} + // clang-format on + }; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Conv_Fwd_Xdl_Cshuffle x; + x.NumDim = prob.NumDim; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, prob.ALayout}; + x.B = TensorDesc{prob.BDataType, prob.BLayout}; + x.E = TensorDesc{prob.EDataType, prob.ELayout}; + x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { + return TensorDesc{dt, lo}; + }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + x.update_prologue(prologue); + x.update_epilogue(epilogue); + result.push_back(x); + } + return result; +} + +// set up instances when not provided with a problem specification, use default operation values +std::vector +Operation_Conv_Fwd_Xdl_Cshuffle::CreateOperations(const std::string& prologue, + const std::string& epilogue) +{ + Problem_Conv_Fwd prob; + return CreateOperations(prob, prologue, epilogue); +} + +static const char* const CopyDevice_ConvTemplate = + R"( +${Prologue} +${Epilogue} + +using CDEElementOp = Epilogue; +using DeviceConv = ck::tensor_operation::device::CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<${NumDim}, ${LayoutA}, ${LayoutB}, ${LayoutDs}, ${LayoutE}, ${ADataType}, ${BDataType}, ${AccDataType}, ${CShuffleDataType}, ${DsDataType}, ${EDataType}, ${AElementwiseOperation}, ${BElementwiseOperation}, ${CDEElementwiseOperation}, ${ConvSpecialization}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, ${MPerBlock}, ${NPerBlock}, ${KPerBlock}, ${AK1}, ${BK1}, ${MPerXDL}, ${NPerXDL}, ${MXdlPerWave}, ${NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, ${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, ${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, ${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, ${BBlockTransferThreadClusterLengths_BK0_N_BK1}, ${BBlockTransferThreadClusterArrangeOrder}, ${BBlockTransferSrcAccessOrder}, ${BBlockTransferSrcVectorDim}, ${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, ${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, ${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, ${CDEBlockTransferScalarPerVector_NPerBlock}>; + +constexpr ck::index_t NumATensor = ck::tensor_operation::device::GetNumABTensors(); +constexpr ck::index_t NumBTensor = ck::tensor_operation::device::GetNumABTensors(); + +extern "C" __global__ void run_${name}( + const ${ADataType}* in_dev, + const ${BDataType}* wei_dev, + ${EDataType}* __restrict__ out_dev, + ck::Array in_lengths, + ck::Array in_strides, + ck::Array wei_lengths, + ck::Array wei_strides, + ck::Array out_lengths, + ck::Array out_strides, + ck::Array conv_filter_strides, + ck::Array conv_filter_dilations, + ck::Array input_left_pads, + ck::Array input_right_pads, + const ${AElementwiseOperation} a_element_op, + const ${BElementwiseOperation} b_element_op, + const ${CDEElementwiseOperation} cde_element_op +){ + + + auto arg = DeviceConv::Argument(in_dev, + wei_dev, + ck::Array{}, + out_dev, + in_lengths, + in_strides, + wei_lengths, + wei_strides, + ck::Array, 0>{}, + ck::Array, 0>{}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + ${AElementwiseOperation}{}, + ${BElementwiseOperation}{}, + ${CDEElementwiseOperation}{1.0f, 1.0f}); + + constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler(); + + // GridwiseGemm + using GridwiseGemm = DeviceConv::GridwiseGemm; + + static constexpr auto I0 = ck::Number<0>{}; + + ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + const ${ADataType}*, + const ${BDataType}*, + typename GridwiseGemm::DsGridPointer, + ${EDataType}, + ${AElementwiseOperation}, + ${BElementwiseOperation}, + ${CDEElementwiseOperation}, + DeviceConv::AGridDesc_AK0_M_AK1, + DeviceConv::BGridDesc_BK0_N_BK1, + DeviceConv::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceConv::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceConv::Block2ETileMap, + ck::tensor_operation::device::ComputePtrOffsetOfStridedBatch, + ck::integral_constant{}, + false, + false> + ( + arg.p_as_grid_.At(I0), + arg.p_bs_grid_.At(I0), + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_ + ); + +} +)"; + +// use hardcoded instances from vector of operations to substitute values into instance template +Solution Operation_Conv_Fwd_Xdl_Cshuffle::ToSolution() const +{ + std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.m_per_block) + "_" + + std::to_string(this->tile_desc.n_per_block) + "_" + + std::to_string(this->tile_desc.k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"NumDim", std::to_string(this->NumDim)}, + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB", ToString(this->B.layout)}, + {"LayoutDs", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.layout); }))}, + {"LayoutE", ToString(this->E.layout)}, + {"ADataType", ToString(this->A.element)}, + {"BDataType", ToString(this->B.element)}, + {"AccDataType", ToString(this->acc)}, + {"ComputeDataType", ToString(this->A.element)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"DsDataType", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.element); }))}, + {"EDataType", ToString(this->E.element)}, + {"AElementwiseOperation", this->a_elem_op}, + {"BElementwiseOperation", this->b_elem_op}, + {"CDEElementwiseOperation", this->cde_elem_op}, + {"Prologue", this->prologue}, + {"Epilogue", this->epilogue}, + {"ConvSpecialization", this->conv_specialization}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", std::to_string(this->tile_desc.block_size)}, + {"MPerBlock", std::to_string(this->tile_desc.m_per_block)}, + {"NPerBlock", std::to_string(this->tile_desc.n_per_block)}, + {"KPerBlock", std::to_string(this->tile_desc.k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, + {"MXdlPerWave", std::to_string(this->tile_desc.m_Xdl_per_wave)}, + {"NXdlPerWave", std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + {"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"BBlockTransferThreadClusterLengths_BK0_N_BK1", + this->b_block_transfer.thread_cluster_length}, + {"BBlockTransferThreadClusterArrangeOrder", + this->b_block_transfer.thread_cluster_arrange_order}, + {"BBlockTransferSrcAccessOrder", this->b_block_transfer.src_access_order}, + {"BBlockTransferSrcVectorDim", std::to_string(this->b_block_transfer.src_vec_dim)}, + {"BBlockTransferSrcScalarPerVector", + std::to_string(this->b_block_transfer.src_scalar_per_vector)}, + {"BBlockTransferDstScalarPerVector_BK1", + std::to_string(this->b_block_transfer.dst_scalar_per_vector_k1)}, + {"BBlockLdsExtraN", std::to_string(this->b_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CDEBlockTransferScalarPerVector_NPerBlock", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + }; + + return Solution{InterpolateString(CopyDevice_ConvTemplate, values), std::move(values)}; +} + +} // namespace conv +} // namespace host +} // namespace ck diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp index 6fcb94cdbdf0c7a53ca9c21d8f57bc0bf9684142..f685aca044156403c01052f61aab3b711a9d8d71 100644 --- a/codegen/src/headers.cpp +++ b/codegen/src/headers.cpp @@ -14,4 +14,4 @@ std::unordered_map GetHeaders() } } // namespace host -} // namespace ck \ No newline at end of file +} // namespace ck diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp index d43df73f3386e872f7cbf784da66f8d4b118da05..a8a8b10c04d522e93dc7340167e46ca83f51259b 100644 --- a/codegen/src/types.cpp +++ b/codegen/src/types.cpp @@ -29,12 +29,20 @@ std::string ToString(DataType dt) throw std::runtime_error("Incorrect data type"); } +Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } + std::string ToString(Layout dl) { switch(dl) { case Layout::Row: return "ck::tensor_layout::gemm::RowMajor"; case Layout::Column: return "ck::tensor_layout::gemm::ColumnMajor"; + case Layout::GKCYX: return "ck::tensor_layout::convolution::GKCYX"; + case Layout::GKYXC: return "ck::tensor_layout::convolution::GKYXC"; + case Layout::GNHWK: return "ck::tensor_layout::convolution::GNHWK"; + case Layout::GNHWC: return "ck::tensor_layout::convolution::GNHWC"; + case Layout::NHWGC: return "ck::tensor_layout::convolution::NHWGC"; + case Layout::NHWGK: return "ck::tensor_layout::convolution::NHWGK"; } throw std::runtime_error("Incorrect layout"); } diff --git a/codegen/src/utils.cpp b/codegen/src/utils.cpp index cd6700c4895b87efe26683a848ab42588a1fe27f..19627d4cf6fa6bcdacee4519a1929b93093cf527 100644 --- a/codegen/src/utils.cpp +++ b/codegen/src/utils.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/host/utils.hpp" diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt index 897cce1c94710c0f16a8980ef6aac5a2ba05dd72..f891286019a227a16324393c9c0796961b33ec97 100644 --- a/codegen/test/CMakeLists.txt +++ b/codegen/test/CMakeLists.txt @@ -1,11 +1,13 @@ - list(APPEND CMAKE_PREFIX_PATH /opt/rocm) add_subdirectory(rtc) - file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) foreach(TEST_SRC ${TEST_SRCS}) -get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) -rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC}) -target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host) -target_include_directories(test_host_${BASE_NAME} PUBLIC include()) + set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP) + get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) + rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC}) + target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host) + # target_link_libraries(test_host_${BASE_NAME} ${CK_ROOT}/build/lib/libutility.a) + target_include_directories(test_host_${BASE_NAME} PUBLIC include()) + target_include_directories(test_host_${BASE_NAME} PUBLIC ${CK_ROOT}/include) + target_include_directories(test_host_${BASE_NAME} PUBLIC ${CK_ROOT}/library/include) endforeach() diff --git a/codegen/test/common.hpp b/codegen/test/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..99d4c6497331f65d19adf302bd47dbaa22ac4b40 --- /dev/null +++ b/codegen/test/common.hpp @@ -0,0 +1,134 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +std::vector get_headers_for_test() +{ + std::vector result; + auto hs = ck::host::GetHeaders(); + std::transform( + hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { + return {p.first, p.second}; + }); + return result; +} + +template +std::size_t GetSize(V mLens, V mStrides) +{ + std::size_t space = 1; + for(std::size_t i = 0; i < mLens.Size(); ++i) + { + if(mLens[i] == 0) + continue; + + space += (mLens[i] - 1) * mStrides[i]; + } + return space; +} + +template +rtc::buffer generate_buffer(V mLens, V mStrides, std::size_t seed = 0) +{ + std::size_t space = GetSize(mLens, mStrides); + rtc::buffer result(space); + std::mt19937 gen(seed); + std::uniform_real_distribution dis(-1.0); + std::generate(result.begin(), result.end(), [&] { return dis(gen); }); + // std::fill(result.begin(), result.end(), 1); + return result; +} + +template +bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) +{ + return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) { + return fabs(x - y) < atol + rtol * fabs(y); + }); +} + +std::string classify(double x) +{ + switch(std::fpclassify(x)) + { + case FP_INFINITE: return "inf"; + case FP_NAN: return "nan"; + case FP_NORMAL: return "normal"; + case FP_SUBNORMAL: return "subnormal"; + case FP_ZERO: return "zero"; + default: return "unknown"; + } +} + +template +void print_classification(const Buffer& x) +{ + std::unordered_set result; + for(const auto& i : x) + result.insert(classify(i)); + for(const auto& c : result) + std::cout << c << ", "; + std::cout << std::endl; +} + +template +void print_statistics(const Buffer& x) +{ + std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", "; + std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", "; + double num_elements = x.size(); + auto mean = + std::accumulate(x.begin(), x.end(), double{0.0}, std::plus{}) / num_elements; + auto stddev = std::sqrt( + std::accumulate(x.begin(), + x.end(), + double{0.0}, + [&](double r, double v) { return r + std::pow((v - mean), 2.0); }) / + num_elements); + std::cout << "Mean: " << mean << ", "; + std::cout << "StdDev: " << stddev << "\n"; +} + +template +void print_preview(const Buffer& x) +{ + if(x.size() <= 10) + { + std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; }); + } + else + { + std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; }); + std::cout << "..., "; + std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; }); + } + std::cout << std::endl; +} + +template +struct check_all +{ + rtc::buffer data{}; + bool operator()(const rtc::buffer& x) + { + if(data.empty()) + { + data = x; + return true; + } + return allclose(data, x); + } +}; + +template +auto report(const Solution& solution, bool pass) +{ + return test::make_predicate(solution.ToTemplateString(), [=] { return pass; }); +} diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index 17b659993a8a33d56fc97d47daa9cfd6f35f026f..bd7ef463fbe64d5bc3d07665cb4757598657f2ad 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -10,6 +10,7 @@ #include #include #include +#include using half = _Float16; // using half = __fp16; @@ -159,7 +160,10 @@ TEST_CASE(test_problem_kernel) auto b = to_gpu(generate_buffer(1024 * 1024, 1)); auto c = to_gpu(generate_buffer(1024 * 1024, 2)); - for(auto solution : prob.GetSolutions("gfx90a")) + std::string epilogue = ""; + std::string prologue = ""; + + for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue)) { auto src = ck::host::InterpolateString(gemm_compile_check, {{"include", prob.GetIncludeHeader()}, @@ -178,6 +182,7 @@ TEST_CASE(test_problem_kernel) auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) * ck::host::integer_divide_ceil(prob.N, n_per_block); k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data()); + CHECK(report(solution, check(rtc::from_gpu(c)))); } } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3c477692e5d7cceef6cd5c57088affd051b33bad --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include +#include +#include +#include "common.hpp" +#include + +// Need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {2, 2}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {1, 1}; + ck::Array input_right_pads = {1, 1}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {2, 2}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {1, 1}; + std::vector input_right_pads_ = {1, 1}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ec9bd2b78159e9af39721115c3f1bef9c06b1dfa --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include +#include +#include +#include + +// need this for validation +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {1, 1}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {0, 0}; + ck::Array input_right_pads = {0, 0}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {1, 1}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {0, 0}; + std::vector input_right_pads_ = {0, 0}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9850184c5ec0c6a957323898de1af03897971188 --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "common.hpp" +#include +#include +#include +#include + +// need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {2, 2}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {0, 0}; + ck::Array input_right_pads = {0, 0}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {2, 2}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {0, 0}; + std::vector input_right_pads_ = {0, 0}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp new file mode 100644 index 0000000000000000000000000000000000000000..907f744db402183347d44aeb63083c68866d0128 --- /dev/null +++ b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp @@ -0,0 +1,209 @@ +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" +#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/tensor_operation/gpu/device/helper.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "common.hpp" +#include +#include +#include +#include + +// need this for verification +/**struct Epilogue +{ + Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +};**/ +const std::string conv_compile_check = R"__ck__( +#include <${include}> + +${template}; + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + // set up problem specification + ck::host::conv::Problem_Conv_Fwd prob; + prob.NumDim = 2; + prob.G = 32; + prob.N = 256; + prob.C = 32; + prob.K = 64; + prob.Y = 3; + prob.X = 3; + prob.Hi = 28; + prob.Wi = 28; + prob.Ho = 28; + prob.Wo = 28; + check_all check; + + // user provided fusion operations + std::string epilogue = R"( +struct Epilogue +{ + __host__ __device__ Epilogue(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()(ck::half_t& e, + const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * e + beta_ * ck::type_convert(d)); + } + + float alpha_; + float beta_; +}; +)"; + std::string prologue = ""; + + // length+stride arrays + ck::Array in_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.C), + static_cast(prob.Hi), + static_cast(prob.Wi)}; + ck::Array out_lengths{static_cast(prob.G), + static_cast(prob.N), + static_cast(prob.K), + static_cast(prob.Ho), + static_cast(prob.Wo)}; + ck::Array wei_lengths{static_cast(prob.G), + static_cast(prob.K), + static_cast(prob.C), + static_cast(prob.Y), + static_cast(prob.X)}; + ck::Array d_lengths = {}; + + ck::Array in_strides{static_cast(prob.C), + static_cast(prob.Hi * prob.Wi * prob.G * prob.C), + 1, + static_cast(prob.Wi * prob.G * prob.C), + static_cast(prob.G * prob.C)}; + ck::Array out_strides{static_cast(prob.K), + static_cast(prob.Ho * prob.Wo * prob.G * prob.K), + 1, + static_cast(prob.Wo * prob.G * prob.K), + static_cast(prob.G * prob.K)}; + ck::Array wei_strides{static_cast(prob.K * prob.Y * prob.X * prob.C), + static_cast(prob.Y * prob.X * prob.C), + 1, + static_cast(prob.X * prob.C), + static_cast(prob.C)}; + ck::Array d_strides = {}; + + ck::Array conv_filter_strides = {1, 1}; + ck::Array conv_filter_dilations = {1, 1}; + ck::Array input_left_pads = {1, 1}; + ck::Array input_right_pads = {1, 1}; + + // move the data onto the device + auto in_dev = + to_gpu(generate_buffer>(in_lengths, in_strides, 0)); + auto wei_dev = + to_gpu(generate_buffer>(wei_lengths, wei_strides, 1)); + auto out_dev = + to_gpu(generate_buffer>(out_lengths, out_strides, 2)); + + // CK Verficiation: Reference Kernel + /**bool pass = true; + Tensor in_host(in_lengths, in_strides); + in_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor wei_host(wei_lengths, wei_strides); + wei_host.GenerateTensorValue(GeneratorTensor_1{1}); + Tensor out_host(out_lengths, out_strides); + + std::vector conv_filter_strides_ = {1, 1}; + std::vector conv_filter_dilations_ = {1, 1}; + std::vector input_left_pads_ = {1, 1}; + std::vector input_right_pads_ = {1, 1}; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + 2, + ck::half_t, + ck::half_t, + ck::half_t, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Epilogue>(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_host, + wei_host, + out_host, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + Epilogue{1.0f, 1.0f}); + out_host.SetZero(); + ref_invoker.Run(ref_argument);**/ + + for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue)) + { + // substitute instance values into the template + auto src = ck::host::InterpolateString( + conv_compile_check, + {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}}); + + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + auto name = solution.GetTemplateParameter("name"); + options.kernel_name = "run_" + name; + auto k = rtc::compile_kernel(srcs, options); + + // Grid size calculation + auto block_size = solution.GetTemplateParameter("BlockSize"); + + auto tmp = get_launch_params(solution, out_lengths, out_strides); + + auto grid_size = tmp * in_lengths[1]; + + // launch the kernel with arguments needed for the argument pointer + k.launch(nullptr, grid_size * block_size, block_size)(in_dev.data(), + wei_dev.data(), + out_dev.data(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + // auto res = rtc::from_gpu(out_dev); + // pass &= ck::utils::check_err(res, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + // assert(pass); + + // Simple check: this checks that the output from each instance matches the output from the + // first instance + CHECK(report(solution, check(rtc::from_gpu(out_dev)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 7ea55b9328af215cba5602e6801c71f53da2ee4c..d84ebf4de9ed8cfaaedbb3bb3f67a455c6dfbf3a 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -56,6 +56,8 @@ void write_string(const std::string& filename, const std::string_view& buffer) } std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device-only"; } +// TODO: undo after extracting the codeobj +// std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; } kernel compile_kernel(const std::vector& srcs, compile_options options) { @@ -89,6 +91,12 @@ kernel compile_kernel(const std::vector& srcs, compile_options options auto obj = read_buffer(out_path.string()); + std::ofstream ofh("obj.o", std::ios::binary); + for(auto i : obj) + ofh << i; + ofh.close(); + // int s = std::system(("/usr/bin/cp " + out_path.string() + " codeobj.bin").c_str()); + // assert(s == 0); return kernel{obj.data(), options.kernel_name}; } diff --git a/codegen/test/rtc/src/hip.cpp b/codegen/test/rtc/src/hip.cpp index 10e38c9adb8a1351e65b81d0c998fcb107388446..747f83e3baa240159adcf2e89847f4a1bad245a8 100644 --- a/codegen/test/rtc/src/hip.cpp +++ b/codegen/test/rtc/src/hip.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace rtc { @@ -49,7 +50,10 @@ std::size_t get_available_gpu_memory() size_t total; auto status = hipMemGetInfo(&free, &total); if(status != hipSuccess) - throw std::runtime_error("Failed getting available memory: " + hip_error(status)); + { + std::cerr << "Failed getting available memory: " + hip_error(status) << std::endl; + return (8ull * 1024ull * 1024ull * 1024ull); + } return free; } diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index a854542439f116c1f70ce6ad28ae4a95320ff593..51bfef2898beb5427b82b97971c8d8a231611261 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.38.1 +rocm-docs-core==1.5.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 801726ed608300bc33e6225c13d5801add69b716..6d2fe6ca57868feaf671b07729889d04bf6d0396 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile requirements.in @@ -48,12 +48,6 @@ idna==3.4 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 - # via - # sphinx - # sphinxcontrib-bibtex -importlib-resources==6.1.0 - # via rocm-docs-core jinja2==3.1.2 # via # myst-parser @@ -99,8 +93,6 @@ pyjwt[crypto]==2.6.0 # via pygithub pynacl==1.5.0 # via pygithub -pytz==2023.3.post1 - # via babel pyyaml==6.0 # via # myst-parser @@ -111,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.38.1 +rocm-docs-core==1.5.0 # via -r requirements.in six==1.16.0 # via @@ -165,7 +157,3 @@ urllib3==1.26.18 # via requests wrapt==1.15.0 # via deprecated -zipp==3.17.0 - # via - # importlib-metadata - # importlib-resources diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 23683de449961827a4ddbf6d41e286217b367489..98fd9c6b773b147327aae0aa2817ea50c014b030 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -22,6 +22,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16) add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2) +add_example_executable(example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3) add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) diff --git a/example/01_gemm/README.md b/example/01_gemm/README.md index 226783b03b0b58cfe4de15a189f07a8324b3bc1d..5edec1f04378963686ab8915b1b3ede93dc76081 100644 --- a/example/01_gemm/README.md +++ b/example/01_gemm/README.md @@ -8,16 +8,20 @@ ./bin/example_gemm_xdl 0 1 5 ``` -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` +# Instructions for ```example_gemm_xdl_fp16_streamk_v3``` + +## Run ```example_gemm_xdl_fp16_streamk_v3``` +```bash +arg1: verification (0=no, 1=yes) +arg2: initialization (0=no init, 1=integer value, 2=decimal value) +arg3: time kernel (0=no, 1=yes) +arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC +arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK) +arg11: Grid_size(-1 for max occupancy) +bin/example_gemm_xdl_fp16_streamk_v3 1 2 1 3840 4096 4096 4096 4096 4096 1 -1 a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s +problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:4032, NP:4096, KRead:4096, KP:4096, AK0:512, BK0:2048, MBlock: 18, NBlock: 16, Stream-K Selection:1, Grid size:-1} +Perf: 0.292022 ms, 441.23 TFlops, 330.348 GB/s, DeviceGemmXdlUniversal BlkSize: 256, BlkTile: 224x256x64, WaveTile: 16x16, WaveMap: 7x8, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2 ``` diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index ef87d9c2fc4767af74a63521be0480b022c13aba..3d8f4565cbf0936e6080a9cadbf1bfe0f3f6a142 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,6 +45,19 @@ struct ProblemSizeStreamK final ck::index_t NumSKBlocks = -1; }; +struct ProblemSizeStreamK_universal final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t Grid_size = -1; // defaults to max occupancy + ck::index_t Streamk_sel = 1; // defaults to 1-tile SK +}; struct ProblemSizeSplitK final { @@ -123,6 +136,57 @@ bool parse_cmd_args(int argc, return true; } +template <> +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeStreamK_universal& problem_size, + ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + + if(argc >= 11) + { + problem_size.Streamk_sel = std::stoi(argv[10]); + problem_size.Grid_size = std::stoi(argv[11]); + } + } + else + { + std::cerr + << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)" + << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; + return false; + } + + return true; +} + template <> bool parse_cmd_args(int argc, char* argv[], @@ -165,7 +229,8 @@ bool parse_cmd_args(int argc, << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl - << "arg10: NumSKBlocks(optional)" << std::endl; + << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)" + << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; return false; } diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 8c52e4f7d703631edd5504967d96647b4c130832..f8afe8d6db4ffba5e3c2ab4e4a31b997decad906 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle - < ALayout, - BLayout, - CLayout, - ADataType, + < ALayout, + BLayout, + CLayout, + ADataType, BDataType, - CDataType, - AccDataType, - CShuffleDataType, - AElementOp, - BElementOp, - CElementOp, - GemmDefault, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, 1, // Prefetch stage 128, // BlockSize 64, // MPerBlock 128, // NPerBlock 64, // KPerBlock - 8, // K1 + 2, // K1 16, // MPerWmma 16, // NPerWmma 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, 1, // C shuffle (M Repeat) Per store 1, // C shuffle (N Repeat) Per store - S<1, 32, 1, 4>, + S<1, 32, 1, 4>, 8>; // clang-format on diff --git a/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b163962b95132e2749151f88c682968417fd361 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 224, 256, + 64, 8, 2, + 16, 16, + 7, 8, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 2, 0, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index b04e4e53a89db56406ff873d8ea0353fa354f049..cb15186c3ba51a1a736272ee509e69ab84b876f8 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; case 4: - ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); break; case 5: diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc new file mode 100644 index 0000000000000000000000000000000000000000..6679f95157a638979c27ccd9cebff7eece444146 --- /dev/null +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -0,0 +1,298 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto Grid_size = problem_size.Grid_size; + auto Streamk_sel = problem_size.Streamk_sel; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + auto f_get_default_streamk_policy = [](ck::index_t streamk_sel) { + if(streamk_sel == -1) + { + return static_cast(4); + } + else + return static_cast(streamk_sel); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Streamk_sel = f_get_default_streamk_policy(Streamk_sel); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2_Streamk_Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + Streamk_sel, + Grid_size, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); +#endif + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_universal_streamk_example(int argc, char* argv[]) +{ + ProblemSizeStreamK_universal problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/example/02_gemm_bilinear/README.md b/example/02_gemm_bilinear/README.md index 9eb87e1e3479d72497ec72956b1de649b0ff735f..a407ce24f7bf72a2c19dbf1a83c455ee955baa2c 100644 --- a/example/02_gemm_bilinear/README.md +++ b/example/02_gemm_bilinear/README.md @@ -9,20 +9,3 @@ #arg11 to 12: alpha, beta ./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5 ``` -Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c0_grid_desc_m_n_{ 3840, 4096} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 1 times... -Perf: 0.936965 ms, 137.517 TFlops, 102.959 GB/s -error: 0 -max_diff: 0, 558.5, 558.5 -``` diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index d1b820da7bb61a0d658b179f394dd2dfeac878af..18731e810e14db2b81a85264f19718a997e1cfa0 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -17,6 +17,7 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" struct AlphaBetaAdd { @@ -175,6 +176,14 @@ int main(int argc, char* argv[]) exit(0); } + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index aca136f8016b4cd925b1a4151d1c34b4df57a73b..87812369bd1f1523cee5672f3f14624a94cc8113 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -17,6 +17,7 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" struct AlphaBetaAdd { @@ -175,6 +176,14 @@ int main(int argc, char* argv[]) exit(0); } + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index ab19f819e8668459d4f816bf1b0a7c219c7d8dc3..be47665a262ec6619816c249bdfdb96ba3c8ae16 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32) set(target 1) endif() -endforeach() \ No newline at end of file +endforeach() diff --git a/example/04_gemm_add_add_fastgelu/README.md b/example/04_gemm_add_add_fastgelu/README.md index 08a55fb9a37f9f7823ae8882d4d36e9bd96ee6fd..7b0d003e59cb6d3c2e53a064b1845f7ab3c5c98a 100644 --- a/example/04_gemm_add_add_fastgelu/README.md +++ b/example/04_gemm_add_add_fastgelu/README.md @@ -8,16 +8,3 @@ #arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE" ./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1 ``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1} -d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8> -``` diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index c5767982732f6edee76c35ba42c03e661599ef3b..8a295d14c4ec5f45aba77d31cd98ff78cecb6fb4 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -3,8 +3,7 @@ add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) -# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed -add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) +add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) add_example_executable(example_convnd_fwd_xdl_fp16_comp_fp8 convnd_fwd_xdl_fp16_comp_fp8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp) diff --git a/example/09_convnd_fwd/README.md b/example/09_convnd_fwd/README.md index 9ab5fee549d44aefd22590d66681127a2816d9f9..22f90ea29af28a1e7b38030971ecf49a8b99b849 100644 --- a/example/09_convnd_fwd/README.md +++ b/example/09_convnd_fwd/README.md @@ -16,17 +16,3 @@ # , (ie RightPy, RightPx for 2D) ./bin/example_convnd_fwd_xdl 0 1 100 ``` - -Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32) -``` -input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} -weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} -output: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} -arg.a_grid_desc_k0_m_k1_{432, 165888, 4} -arg.b_grid_desc_k0_n_k1_{432, 256, 4} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 100 times... -Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s -``` diff --git a/example/15_grouped_gemm/README.md b/example/15_grouped_gemm/README.md index c83b23e08cc7b923ce22316e0c882886d7437ef3..a2afe0f4b9ed003de276d6be71f9efab64525393 100644 --- a/example/15_grouped_gemm/README.md +++ b/example/15_grouped_gemm/README.md @@ -7,19 +7,3 @@ #arg3: run kernel # of times (>1) ./bin/example_grouped_gemm_xdl_fp16 0 1 5 ``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -gemm[0] a_m_k: dim 2, lengths {256, 64}, strides {64, 1} b_k_n: dim 2, lengths {64, 128}, strides {1, 64} c_m_n: dim 2, lengths {256, 128}, strides {128, 1} -gemm[1] a_m_k: dim 2, lengths {512, 128}, strides {128, 1} b_k_n: dim 2, lengths {128, 256}, strides {1, 128} c_m_n: dim 2, lengths {512, 256}, strides {256, 1} -gemm[2] a_m_k: dim 2, lengths {768, 192}, strides {192, 1} b_k_n: dim 2, lengths {192, 384}, strides {1, 192} c_m_n: dim 2, lengths {768, 384}, strides {384, 1} -gemm[3] a_m_k: dim 2, lengths {1024, 256}, strides {256, 1} b_k_n: dim 2, lengths {256, 512}, strides {1, 256} c_m_n: dim 2, lengths {1024, 512}, strides {512, 1} -group: 0 arg.a_grid_desc_k0_m_k1_{8, 256, 8}, arg.b_grid_desc_k0_n_k1_{8, 128, 8}, arg.c_grid_desc_m_n_{ 256, 128} -group: 1 arg.a_grid_desc_k0_m_k1_{16, 512, 8}, arg.b_grid_desc_k0_n_k1_{16, 256, 8}, arg.c_grid_desc_m_n_{ 512, 256} -group: 2 arg.a_grid_desc_k0_m_k1_{24, 768, 8}, arg.b_grid_desc_k0_n_k1_{24, 384, 8}, arg.c_grid_desc_m_n_{ 768, 384} -group: 3 arg.a_grid_desc_k0_m_k1_{32, 1024, 8}, arg.b_grid_desc_k0_n_k1_{32, 512, 8}, arg.c_grid_desc_m_n_{ 1024, 512} -launch_and_time_kernel: grid_dim {30, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 0.037887 ms, 11.0706 TFlops, 90.8132 GB/s, DeviceGroupedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> -``` diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index d80c163e3f64db0e94429a1999905558061006fd..965a0e7e37836c06e3aceb29cd19976444c99707 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -63,7 +63,7 @@ using DeviceGemmInstance = //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>; // clang-format on struct ProblemSize final @@ -92,9 +92,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co auto group_count = problem_size.group_count; using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + using GemmDesc = ck::tensor_operation::device::GemmDesc; // GEMM shape - std::vector gemm_descs; + std::vector gemm_descs; std::vector ggemm_kargs; std::vector p_Cs; std::vector p_As; diff --git a/example/26_contraction/README.md b/example/26_contraction/README.md index c88d93cf83a411e2499cd0be80e524c42db339c6..acbfa84df108289b374c15ab6843a8cfdd99438f 100644 --- a/example/26_contraction/README.md +++ b/example/26_contraction/README.md @@ -7,14 +7,3 @@ #arg3: time kernel (0=no, 1=yes) ./bin/example_contraction_bilinear_xdl_fp32 1 1 1 ``` - -Result (MI100 @ dynammic freq, 46TFlops peak FP32) -``` -a_ms_ks: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1} -b_ks_ns: dim 4, lengths {32, 64, 32, 64}, strides {128, 1, 524288, 4096} -c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1} -launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4> -``` diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 2bbf430c4e239e6cb41b3d71e96c5c6903799712..f556be887f9824c097adbefb5637b985ebf349cd 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN = 2, 4, 4, - true, + false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, - true, + false, 1, 1, S<1, 64, 1, 2>, diff --git a/example/30_grouped_conv_fwd_multiple_d/README.md b/example/30_grouped_conv_fwd_multiple_d/README.md index 7a0cb2d0e4ba47b79b8e524c5ac29eab08dad57b..1165634e1afb85049b688efbcf3882901c955a77 100644 --- a/example/30_grouped_conv_fwd_multiple_d/README.md +++ b/example/30_grouped_conv_fwd_multiple_d/README.md @@ -16,15 +16,3 @@ Following arguments (depending on number of spatial dims): ./bin/example_grouped_conv_fwd_bias_relu_add_xdl_fp16 1 1 1 ``` -Result (MI100) -``` -in: dim 5, lengths {1, 128, 192, 71, 71}, strides {192, 967872, 1, 13632, 192} -wei: dim 5, lengths {1, 256, 192, 3, 3}, strides {442368, 1728, 1, 576, 192} -bias: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0} -residual: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0} -out: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 331776, 1, 9216, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 1.55981 ms, 94.0927 TFlops, 213.868 GB/s, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Default> -``` diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp index 039d25029921491e7e67808e554b8cb3e6eb4745..ff873d26bcf927b9335fc9a2ded2d199ef3384e9 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" // kernel data types using InKernelDataType = FP16; @@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); +} diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp index 793324970e700748ac4b18e9fc429a24f70c007d..662a6f611b5f177251a4bafd32d4415fbc144992 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #include "common_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" // kernel data types using InKernelDataType = I8; @@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp index 2c7bacfc4ebc10be10ff6dcd5d1327820643a4ca..69ab5c5c0bbba9a4d7f7e08579fd94b71083c9a7 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -163,4 +164,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp index d9ab645ee9b97c0932f8060ee480ccbb4fa85f3a..f5cedb14c9298f964763564f2adeef9645d1642a 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -285,4 +286,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp index 4c92c5497fb66557a93acd50b71f965d90c5ec3c..41c6dff2dfefbff3abc5649f7518f13ef70a9d87 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -71,7 +72,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial #define CK_MHA_USE_WAVE_1 #define CK_MHA_USE_WAVE_2 #define CK_MHA_USE_WAVE_4 -#define CK_MHA_USE_WAVE_8 +//#define CK_MHA_USE_WAVE_8 using DeviceMHAFactory = std::tuple< #ifdef CK_MHA_USE_WAVE_1 @@ -277,10 +278,10 @@ using DeviceMHAFactory = S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, // CShuffleBlockTransfer MN 1, 1, S<1, 64, 1, 2>, 8, - MaskingSpec>, + MaskingSpec> #endif #ifdef CK_MHA_USE_WAVE_8 - ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, @@ -351,4 +352,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_cross_attention_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp index 12dcfcc36d9feb1af58598939b1893b6f9732276..955c25f0d143d9f41dead6d18b1ede00ecca99e8 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp @@ -28,6 +28,7 @@ Example is GQA-4 #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -299,4 +300,14 @@ using ReferenceGemm1Instance = #include "run_grouped_query_attention_forward_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp index 694a320a45f6ee8a0e35cf6ba330ddb1d05f0303..112be07c4963e37c3b6c65ccbc5b437ee33fd569 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp @@ -26,6 +26,7 @@ Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.” #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -284,4 +285,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_ #include "run_multi_query_attention_forward_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp index 8e037272b8345a96173cb280a58a8f217a4186dc..9ec1bc933f75d6af399f546a84ef221b910e4e22 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp @@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" template using S = ck::Sequence; @@ -71,7 +72,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial #define CK_MHA_USE_WAVE_1 #define CK_MHA_USE_WAVE_2 #define CK_MHA_USE_WAVE_4 -#define CK_MHA_USE_WAVE_8 +//#define CK_MHA_USE_WAVE_8 using DeviceMHAFactory = std::tuple< #ifdef CK_MHA_USE_WAVE_1 @@ -277,10 +278,10 @@ using DeviceMHAFactory = S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, // CShuffleBlockTransfer MN 1, 1, S<1, 64, 1, 2>, 8, - MaskingSpec>, + MaskingSpec> #endif #ifdef CK_MHA_USE_WAVE_8 - ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, @@ -329,4 +330,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< #include "run_self_attention_wmma.inc" -int main(int argc, char* argv[]) { return run(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp index 5baa521501704a2ea603f5b975cd6b392285cfb5..3e3ae7edbde107fce0e7c1eeea3667472c7e7ce6 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp @@ -3,6 +3,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" #include "common.hpp" +#include "ck/host_utility/device_prop.hpp" using OutDataType = FP16; using WeiDataType = FP16; @@ -31,4 +32,14 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat #include "run_grouped_conv_bwd_data_example.inc" -int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run_grouped_conv_bwd_data_example(argc, argv); +} diff --git a/example/46_gemm_add_multiply/README.md b/example/46_gemm_add_multiply/README.md index ee5cdee3659e7b6c3ba9bc083c48f91fcd66e581..e2de4696f37f5fe4fdd1a0a9a1537ae2b3899ba1 100644 --- a/example/46_gemm_add_multiply/README.md +++ b/example/46_gemm_add_multiply/README.md @@ -8,19 +8,3 @@ #arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE" ./bin/example_gemm_add_multiply_dl_fp16 1 1 1 ``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} -d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1} -d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2} -arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} -arg.e_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 3.99904 ms, 32.22 TFlops, 31.9913 GB/s, DeviceGemmMultipleD_Dl<256, 128, 128, 16, 2, 4, 4, 1> -``` diff --git a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt index 78f6832895a84940c26d08fc35c5975533e62bcc..e49056a948ca6dca6970dcf65e0c3c49089fc88a 100644 --- a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt +++ b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt @@ -1,7 +1,7 @@ add_custom_target(example_grouped_gemm_xdl_multi_abd) add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp) -add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16) +add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16) add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp) -add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8) +add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8) diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eaabccdf2a773f2b298b7059a0a1811618ac88e5 --- /dev/null +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using A0DataType = F8; +using A1DataType = F32; +using B0DataType = F8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using EDataType = F16; +using ComputeDataType = F8; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct Multiply +{ + __host__ __device__ constexpr void + operator()(ck::f8_t& a, const ck::f8_t& a0, const float& a1) const + { + a = ck::type_convert(ck::type_convert(a0) * a1); + } +}; + +using AElementOp = Multiply; +using BElementOp = Multiply; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultipleABD_Xdl_CShuffle< + NumDimM, + NumDimN, + NumDimK, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 1, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A0[M0, M1, K0, K1] + std::vector a0_ms_ks_lengths{30, 128, 32, 64}; + std::vector a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1}; + // A1[M1, K1] -> A1[M0, M1, K0, K1] + std::vector a1_ms_ks_lengths{30, 128, 32, 64}; + std::vector a1_ms_ks_strides{0, 64, 1, 0}; + // B0[N0, N1, K0, K1] + std::vector b0_ns_ks_lengths{32, 64, 32, 64}; + std::vector b0_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; + // B1[N0, N1, K0, K1] + std::vector b1_ns_ks_lengths{32, 64, 32, 64}; + std::vector b1_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); + Tensor b0_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); + Tensor b1_ns_ks(b1_ns_ks_lengths, b1_ns_ks_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; + std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; + + std::cout << "b0_ns_ks: " << b0_ns_ks.mDesc << std::endl; + std::cout << "b1_ns_ks: " << b1_ns_ks.mDesc << std::endl; + + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_ms_ks.mData.data()); + a1_device_buf.ToDevice(a1_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{}, + e_device_buf.GetDeviceBuffer(), + std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, + std::array, 2>{a0_ms_ks_strides, a1_ms_ks_strides}, + std::array, 2>{b0_ns_ks_lengths, b1_ns_ks_lengths}, + std::array, 2>{b0_ns_ks_strides, b1_ns_ks_strides}, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + PassThrough{}); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_contraction with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + { + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a0_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + if(do_verification) + { + + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + + for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < a_ms_ks.mDesc.GetLengths()[1]; ++m1) + { + for(size_t k0 = 0; k0 < a_ms_ks.mDesc.GetLengths()[2]; ++k0) + { + for(size_t k1 = 0; k1 < a_ms_ks.mDesc.GetLengths()[3]; ++k1) + { + a_element_op(a_ms_ks(m0, m1, k0, k1), + a0_ms_ks(m0, m1, k0, k1), + a1_ms_ks(m0, m1, k0, k1)); + } + } + } + } + + Tensor b_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); + + for(size_t n0 = 0; n0 < b_ns_ks.mDesc.GetLengths()[0]; ++n0) + { + for(size_t n1 = 0; n1 < b_ns_ks.mDesc.GetLengths()[1]; ++n1) + { + for(size_t k0 = 0; k0 < b_ns_ks.mDesc.GetLengths()[2]; ++k0) + { + for(size_t k1 = 0; k1 < b_ns_ks.mDesc.GetLengths()[3]; ++k1) + { + b_element_op(b_ns_ks(n0, n1, k0, k1), + b0_ns_ks(n0, n1, k0, k1), + b1_ns_ks(n0, n1, k0, k1)); + } + } + } + } + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = ref_op.MakeArgument( + a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt index 5a35f9b6080df108bff3dc424921d8e2ea49125c..96d868de098f0c62a224d9c02e2fc712bb9eb963 100644 --- a/example/62_convnd_activ/CMakeLists.txt +++ b/example/62_convnd_activ/CMakeLists.txt @@ -1,4 +1,7 @@ add_subdirectory(binary) +add_subdirectory(convinvscale) +add_subdirectory(convscale) +add_subdirectory(convscale_relu) add_subdirectory(multi_AB) add_subdirectory(unary) diff --git a/example/62_convnd_activ/convinvscale/CMakeLists.txt b/example/62_convnd_activ/convinvscale/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..07f42075bd0a8a211f1c72cd7e60c6673469f846 --- /dev/null +++ b/example/62_convnd_activ/convinvscale/CMakeLists.txt @@ -0,0 +1,10 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl_convinvscale) + add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convinvscale example_convnd_fwd_xdl_convinvscale_fp8) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/example/62_convnd_activ/convinvscale/convnd_fwd_convinvscale_common.hpp b/example/62_convnd_activ/convinvscale/convnd_fwd_convinvscale_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4b2ebf84849e055d25da39759c5939387eaab27e --- /dev/null +++ b/example/62_convnd_activ/convinvscale/convnd_fwd_convinvscale_common.hpp @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // random scale values + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(e_g_n_k_wos_lengths, b_g_k_c_xs_lengths, ds_size); + std::size_t num_btype = conv_param.GetInputByte() + + conv_param.GetWeightByte() + sizeof(float) + + sizeof(float) + sizeof(float) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + out_host.ForEach([&](auto&, auto idx) { out_element_op(out_host(idx), c(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); + } + + return true; +} diff --git a/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp b/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fbdfc72063dc4210c3ee429dd3e369398dc748ed --- /dev/null +++ b/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_convinvscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvInvscale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convinvscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convinvscale/run_convnd_fwd_convinvscale_example.inc b/example/62_convnd_activ/convinvscale/run_convnd_fwd_convinvscale_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..797146060216ead74f0b33c61e7c236d9bbb5772 --- /dev/null +++ b/example/62_convnd_activ/convinvscale/run_convnd_fwd_convinvscale_example.inc @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + // instantiate in and wei element ops, will + // instantiate out_element_op below for every iteration + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = + [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto ds_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using DsLayout = decltype(ds_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ck::Tuple<>{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ck::Tuple<>{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ck::Tuple<>{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale/CMakeLists.txt b/example/62_convnd_activ/convscale/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9264da24a69896f309796154084098d61d5e43ac --- /dev/null +++ b/example/62_convnd_activ/convscale/CMakeLists.txt @@ -0,0 +1,20 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl_convscale) + add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8 ) + + add_example_executable(example_convnd_fwd_xdl_convscale_bf8 convnd_fwd_xdl_convscale_bf8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8) + + add_example_executable(example_convnd_fwd_xdl_convscale_fp8_bf8 convnd_fwd_xdl_convscale_fp8_bf8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8_bf8) + + add_example_executable(example_convnd_fwd_xdl_convscale_bf8_fp8 convnd_fwd_xdl_convscale_bf8_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8_fp8) + + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp b/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..978221f8e1093485cec0e98b140d801c1ca01335 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // random scale values + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t ds_size = 3; // 3 element-wise scale multipliers + std::size_t flop = GetFlops(e_g_n_k_wos_lengths, b_g_k_c_xs_lengths, ds_size); + std::size_t num_btype = conv_param.GetInputByte() + + conv_param.GetWeightByte() + sizeof(float) + + sizeof(float) + sizeof(float) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + out_host.ForEach([&](auto&, auto idx) { out_element_op(out_host(idx), c(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c1c8c3a57f0cc5517bd3621c883fb64870bfab4c --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = InDataType; +using BComputeDataType = AComputeDataType; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8590d0620f4d280f51ed0d4bd8672a60da435624 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::bf8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a7d69ccffc10056b6b582d6c7124e3d5d5928da2 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ab59e08a800435657517ecc6a6e1398820a39247 --- /dev/null +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_convscale_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::bf8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScale; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale/run_convnd_fwd_convscale_example.inc b/example/62_convnd_activ/convscale/run_convnd_fwd_convscale_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..797146060216ead74f0b33c61e7c236d9bbb5772 --- /dev/null +++ b/example/62_convnd_activ/convscale/run_convnd_fwd_convscale_example.inc @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + // instantiate in and wei element ops, will + // instantiate out_element_op below for every iteration + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = + [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto ds_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using DsLayout = decltype(ds_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ck::Tuple<>{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ck::Tuple<>{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ck::Tuple<>{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale_relu/CMakeLists.txt b/example/62_convnd_activ/convscale_relu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..95589cedcb406d55b9e91de6d7cdb1a06a87a01d --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/CMakeLists.txt @@ -0,0 +1,11 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl_convscale_relu) + add_example_executable(example_convnd_fwd_xdl_convscale_relu_fp8 convnd_fwd_xdl_convscale_relu_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale_relu example_convnd_fwd_xdl_convscale_relu_fp8 ) + + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/convscale_relu/convnd_fwd_convscale_relu_common.hpp b/example/62_convnd_activ/convscale_relu/convnd_fwd_convscale_relu_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d2dacc20524d9b181251dfb1143108e7edba44bd --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/convnd_fwd_convscale_relu_common.hpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths, + const std::size_t& ds_size) +{ + // G * N * C * * (2 * K * + + // ) + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return G * N * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + (static_cast(2) * K * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()) + + ds_size); +} + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // random scale values + float scale_in = float(std::rand()) / float(RAND_MAX); + float scale_wei = float(std::rand()) / float(RAND_MAX); + float scale_out = float(std::rand()) / float(RAND_MAX); + + std::cout << std::endl; + std::cout << "scale_in: " << scale_in << std::endl; + std::cout << "scale_wei: " << scale_wei << std::endl; + std::cout << "scale_out: " << scale_out << std::endl; + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{}, + std::array, 0>{}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t ds_size = 3 + 1; // 3 element-wise scale multipliers + 1 element-wise relu + std::size_t flop = GetFlops(e_g_n_k_wos_lengths, b_g_k_c_xs_lengths, ds_size); + std::size_t num_btype = conv_param.GetInputByte() + + conv_param.GetWeightByte() + sizeof(float) + + sizeof(float) + sizeof(float) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + out_host.ForEach([&](auto&, auto idx) { out_element_op(out_host(idx), c(idx)); }); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, + out_host, + "Error: incorrect results!", + get_rtol(), + get_atol()); + } + + return true; +} diff --git a/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp b/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..360349e7eca7526bd9e1af814213c540e116d51c --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_convscale_relu_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using DsDataType = ck::Tuple<>; +using OutDataType = ck::f8_t; +using AComputeDataType = ck::f8_t; +using BComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ConvScaleRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + DsLayout, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + DsDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeDataType, + BComputeDataType>; + +#include "run_convnd_fwd_convscale_relu_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/62_convnd_activ/convscale_relu/run_convnd_fwd_convscale_relu_example.inc b/example/62_convnd_activ/convscale_relu/run_convnd_fwd_convscale_relu_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..797146060216ead74f0b33c61e7c236d9bbb5772 --- /dev/null +++ b/example/62_convnd_activ/convscale_relu/run_convnd_fwd_convscale_relu_example.inc @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + // instantiate in and wei element ops, will + // instantiate out_element_op below for every iteration + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = + [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto ds_layout, auto out_layout) { + constexpr ck::index_t ndim_spatial_value = ndim_spatial.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using DsLayout = decltype(ds_layout); + using OutLayout = decltype(out_layout); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd>( + do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op); + }; + + namespace ctc = ck::tensor_layout::convolution; + + if(conv_param.num_dim_spatial_ == 1) + { + return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ck::Tuple<>{}, ctc::GNWK{}); + } + else if(conv_param.num_dim_spatial_ == 2) + { + return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ck::Tuple<>{}, ctc::GNHWK{}); + } + else if(conv_param.num_dim_spatial_ == 3) + { + return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ck::Tuple<>{}, ctc::GNDHWK{}); + } + + return true; +} diff --git a/example/62_convnd_activ/unary/CMakeLists.txt b/example/62_convnd_activ/unary/CMakeLists.txt index 94ffb3661ceb2787b292fa6549c18391d2389c52..3470e9b9456f7361e1a05c07e99b00fe90990a31 100644 --- a/example/62_convnd_activ/unary/CMakeLists.txt +++ b/example/62_convnd_activ/unary/CMakeLists.txt @@ -30,6 +30,16 @@ foreach(gpu IN LISTS GPU_TARGETS) # Elu add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp) add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_elu_fp16) + # Swish + add_example_executable(example_convnd_fwd_xdl_swish_fp16 convnd_fwd_xdl_swish_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_swish_fp16) + # PassThrough + add_example_executable(example_convnd_fwd_xdl_passthrough_fp16 convnd_fwd_xdl_passthrough_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_passthrough_fp16) + # Logistic + add_example_executable(example_convnd_fwd_xdl_logistic_fp16 convnd_fwd_xdl_logistic_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_logistic_fp16) + set(target 1) endif() endforeach() diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_logistic_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_logistic_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..86811c2e969bf47c68b024bcea6f504a4b335e7f --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_logistic_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Logistic; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_passthrough_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_passthrough_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7167c4a84a5c9eae3084fef0c5af97becbe67348 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_passthrough_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_swish_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_swish_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..65a2a5023eb26a542a75e4e4c0dde83824e139a3 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_swish_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Swish; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d968bdb9d1ff2d9a2359cf38cc1a2cbb2d3b97fc --- /dev/null +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) +add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) diff --git a/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5fea43ffc35dbd17349de9b84f7b726fa5882047 --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct AddAdd +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c + d0 + d1; + + e = ck::type_convert(x0_f); + } +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = K; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD, StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c584ff20cf27b68ef9702cdc86cc3961e8155b46 --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using B0DataType = FP8; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RRR + ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +///###### RCR + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 5465adb7798285ebdc5117ba8835731f884ac513..45cfee4de93e4a2e2f79b9808eed86abd229d52f 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -44,6 +44,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(EX_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(EX_TARGETS ${GPU_TARGETS}) + endif() + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") @@ -53,23 +60,30 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any XDL examples if gfx9 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") message("removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #only continue if there are some source files left on the list if(FILE_NAME) + if(FILE_NAME MATCHES "_xdl") + list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(FILE_NAME MATCHES "_wmma") + list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) add_dependencies(examples ${EXAMPLE_NAME}) add_dependencies(check ${EXAMPLE_NAME}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) @@ -118,6 +132,12 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(EX_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(EX_TARGETS ${GPU_TARGETS}) + endif() #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") @@ -127,23 +147,30 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any XDL examples if gfx9 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") message("removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #only continue if there are some source files left on the list if(FILE_NAME) + if(FILE_NAME MATCHES "_xdl") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(FILE_NAME MATCHES "_wmma") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) add_dependencies(examples ${EXAMPLE_NAME}) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() @@ -154,7 +181,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) FOREACH(subdir ${dir_list}) - IF(IS_DIRECTORY "${subdir}") + if(IS_DIRECTORY "${subdir}" AND EXISTS "${subdir}/CMakeLists.txt") add_subdirectory(${subdir}) ENDIF() ENDFOREACH() diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index e31c96caaa2fe581304b044ea3bf0a5fc1198399..e30e9e793c918955075e10a79189e967e45db4e9 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,27 +1,47 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt + --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt ) -# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt +) + +# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory # as current cmake list, otherwise will not figure out the dependency properly -file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR} +) + +add_custom_command( + OUTPUT ${FMHA_BWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -message("adding tile_example ${EXAMPLE_NAME}") +message("adding example ${EXAMPLE_FMHA_FWD}") add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) +set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_FMHA_BWD}") +add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp) +target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) + # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) @@ -29,16 +49,27 @@ if(NOT DEFINED FMHA_FWD_FAST_EXP2) endif() set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) +set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... because they are auto-generated if(FMHA_FWD_FAST_EXP2) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) + list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) + list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() # Allow comparing floating points directly in order to check sentinel values list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) +list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) +target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 5a428e4d4193f6b3136a0da70b1176bf4e738101..0bb540877283e373bb313c9dcbaef685b02000d7 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -30,27 +30,30 @@ args: -mode kernel mode. 0:batch, 1:group (default:0) -b batch size (default:2) -h num of head, for q (default:8) - -h_k num of head, for k/v, 0 means equal to h (default:0) + -h_k num of head, for k/v, -1 means equal to h (default:-1) if not equal to h, then this is GQA/MQA case -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary - -s_k seqlen_k, 0 means equal to s (default:0) + also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) + -s_k seqlen_k, -1 means equal to s (default:-1) -d head dim for q, k (default:128) - -d_v head dim for v, 0 means equal to d (default:0) + -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) note when squant=1, this value will be modified by range_q/k - -range_q per-tensor quantization range of q. used if squant=1. (default:2) - -range_k per-tensor quantization range of k. used if squant=1. (default:2) - -range_v per-tensor quantization range of v. used if squant=1. (default:2) + -range_q per-tensor quantization range of q. used if squant=1. (default:16) + -range_k per-tensor quantization range of k. used if squant=1. (default:16) + -range_v per-tensor quantization range of v. used if squant=1. (default:16) -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) - -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:2) - -squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0) - 1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p, - scale_o according to range_q, range_k, range_v, range_p, range_o + -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) + -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto) + 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O. + calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o -iperm permute input (default:1) if true, will be b*h*s*d, else b*s*h*d -operm permute output (default:1) - -bias add bias or not (default:0) + -bias n or 0, no bias (default:n) + e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s + a(libi) or 2, alibi with 1*h. a:1, b*h -prec data type. fp16/bf16/fp8/bf8 (default:fp16) -mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0) 't', top-left causal mask, 'b', bottom-r causal mask @@ -59,11 +62,14 @@ args: 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now) - -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) -lse 0 not store lse, 1 store lse (default:0) -kname if set to 1 will print kernel name (default:0) - -init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1) + -init init method. ui, uniform random int, ni, normalized random int (default:uf) + uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization + -seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939) + -warmup number of iterations before benchmark the kernel (default:5) + -repeat number of iterations to benchmark the kernel (default:20) ``` Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. @@ -85,6 +91,9 @@ If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support prov ### attention bias Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number. +### alibi +alibi is supported + ### lse For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1` diff --git a/example/ck_tile/01_fmha/bias.hpp b/example/ck_tile/01_fmha/bias.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f9dc656f6374594acb269949c839ad46e12759e8 --- /dev/null +++ b/example/ck_tile/01_fmha/bias.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +// keep sync with BlockAttentionBiasEnum +enum class bias_enum +{ + no_bias = 0, + elementwise_bias = 1, + alibi = 2, +}; + +struct bias_info +{ + bias_enum type; + /* + * simple dispatch logic + * + * if type == elementwise_bias: + * if rank_info == 0: + * bias is 1*1*s*s + * elif rank_info == 1: + * bias is 1*h*s*s + * elif rank_info == 2: + * bias is b*h*s*s + * + * elif type == alibi: + * if rank_info == 0: + * alibi in 1*h + * elif rank_info == 1: + * alibi in b*h + */ + int rank_info; + + void serialize(std::ostream& os) const + { + if(type == bias_enum::no_bias) + os << "n"; + else if(type == bias_enum::elementwise_bias) + { + os << "e"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + else if(type == bias_enum::alibi) + { + os << "alibi"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + } + + static bias_info decode(std::string str) + { + bias_info info{bias_enum::no_bias, 0}; + if(str == "0" || str == "n") + { + info.type = bias_enum::no_bias; + } + else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || + str.compare(0, 11, "elementwise") == 0) + { + info.type = bias_enum::elementwise_bias; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 || + str.compare(0, 5, "alibi") == 0) + { + info.type = bias_enum::alibi; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + return info; + } + + friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + { + bi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/01_fmha/codegen/__init__.py b/example/ck_tile/01_fmha/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/example/ck_tile/01_fmha/codegen/cmake_config.py b/example/ck_tile/01_fmha/codegen/cmake_config.py new file mode 100644 index 0000000000000000000000000000000000000000..03ebfd67021360c9f149d12232c790fa9ca06fb1 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/cmake_config.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +GEN_DIR = "" # in Cmake, have to generate files in same folder \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py new file mode 100644 index 0000000000000000000000000000000000000000..d3d215f7f5f3f7a9e5bfe4dcc2088cf049b905ee --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +DTYPE_MAP = { + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "fp8" : "ck_tile::fp8_t" +} + +MASK_IMPL = { + "generic" : "ck_tile::GenericAttentionMask", + "simplified" : "ck_tile::SimplifiedGenericAttentionMask" +} + +_MASK_SIMPLIFIED_MAP = { + "s_no" : "ck_tile::SimplifiedGenericAttentionMask", + "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", +} + +_MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +def get_mask_map(mask : str): + if mask == "generic": + return _MASK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_MAP + else: + assert False + return None + +_MASK_CHECK_MAP = { + "no" : "t.mask_type == mask_enum::no_mask", + "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", + "generic" : "t.mask_type == mask_enum::window_generic", +} + +_MASK_SIMPLIFIED_CHECK_MAP = { + "s_no" : "t.mask_type == mask_enum::no_mask", + "s_mask" : "t.mask_type != mask_enum::no_mask", +} + +def get_mask_check_map(mask : str): + if mask == "generic": + return _MASK_CHECK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_CHECK_MAP + else: + assert False + return None + +BIAS_MAP = { + "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" +} + +# TODO: this is ugly +BIAS_CHECK_MAP = { + "no" : "bias_enum::no_bias", + "bias" : "bias_enum::elementwise_bias", + "alibi" : "bias_enum::alibi" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +LAYOUT_MAP = { + "row" : "true", + "col" : "false" +} + +PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", +} + +PIPELINE_ENUM_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", +} + +BOOL_MAP = { + "t" : "true", + "f" : "false" +} \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/__init__.py b/example/ck_tile/01_fmha/codegen/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..0df115dc3df9b35a609f0ba9f764afd4d3f9b836 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -0,0 +1,613 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +BWD_DQDKDV_PIPELINE_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdDQDKDVPipelineQSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKSVR", +} + +BWD_DQDKDV_PIPELINE_ENUM_MAP = { + "ks_kts_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSKTSVR", + "qs_ks_vr_dos" : "ck_tile::BlockFmhaBwdPipelineEnum::QSKSVROGradS", + "ks_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KSVR", +} + +FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_bwd.hpp" +""" + +FMHA_BWD_DQ_DK_DV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>; +using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>; +using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; +using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape +// G0&G2 -> GSdP +// G1&G3 -> GdKV +// G4 -> GdQ +using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; + +using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + {F_dbias}, + false, + {F_dropout}, + false, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem< + typename FmhaBwdTypeConfig::QDataType, + typename FmhaBwdTypeConfig::KDataType, + typename FmhaBwdTypeConfig::VDataType, + typename FmhaBwdTypeConfig::GemmDataType, + typename FmhaBwdTypeConfig::LSEDataType, + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::BiasDataType, + typename FmhaBwdTypeConfig::RandValOutputDataType, + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::QGradDataType, + typename FmhaBwdTypeConfig::KGradDataType, + typename FmhaBwdTypeConfig::VGradDataType, + typename FmhaBwdTypeConfig::BiasGradDataType, + fmha_bwd_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_bwd_trait_{F_idx}>; + +using fmha_bwd_pipeline_{F_idx} = {F_pipeline}< + fmha_bwd_pipeline_problem_{F_idx}>; + +using fmha_bwd_dk_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, + false, false>>; + +using fmha_bwd_dv_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, + false, false>>; + +using fmha_bwd_dq_dk_dv_kernel_{F_idx} = + ck_tile::FmhaBwdDQDKDVKernel, + fmha_bwd_pipeline_{F_idx}, + fmha_bwd_dk_epilogue_{F_idx}, + fmha_bwd_dv_epilogue_{F_idx}>; + +using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +std::string fmha_bwd_dq_dk_dv_get_name_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp" +FMHA_BWD_API=""" +#include + +template +float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} + ); +}} + +float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; + r = fmha_bwd_(s, a); + return r; + }} +""" + +@dataclass +class FmhaBwdDQDKDVApiTrait: + pipeline : str + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along k seqlen + bhdq : int # q head_dim + bhdv : int # v head_dim + mask : str + bias : str + dbias : str + dropout : str + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.pipeline}-{self.hdim}-{self.dtype}-{self.mode}-{self.mask}-{self.bias}-{self.dbias}-{self.dropout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + def scheck(self, spad1 : str) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad == 't' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} != 0' + elif self.spad == 'f' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 256 != 0' # BlockSize + else: # self.skpad == 'f' and skpad1 == 'f' + return f'a.seqlen_q % 256 == 0' # BlockSize + + @property + def skcheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.skpad == 't': + return f'a.seqlen_k % {self.bn0} != 0' + else: + return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' + else : return f'a.hdim_q % {self.bhdq} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' + else : return f'a.hdim_v % {self.bhdv} == 0' + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = dict() + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.dq_dk_dv_pool.keys(): + self.dq_dk_dv_pool[trait.dtype] = dict() + if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): + self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() + + self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): + traits=self.dq_dk_dv_pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + for spad1 in ["t", "f"]: + if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): + continue + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], + F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) + + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + +# GEMM0: Q@K=S^T +# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) +# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) +# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) +# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) +# Is it necessary to distinguish between K0~K4? +@dataclass +class FmhaBwdDQDKDVTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along gemm0 unroll(F_bhdq) + F_bk1 : int # tile size along gemm1 unroll(F_bm0) + F_bk2 : int # tile size along gemm2 unroll(F_bhdv) + F_bk3 : int # tile size along gemm3 unroll(F_bm0) + F_bk4 : int # tile size along gemm4 unroll(F_bn0) + F_bhdq : int # q head_dim + F_bhdv : int # v head_dim + F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2 + F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2 + F_rk0 : int # number of warps along gemm-k (not used) in gemm0/gemm2 + F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3 + F_rn1 : int # number of warps along q seqlen (block warps) in gemm1/gemm3 + F_rk1 : int # number of warps along gemm-k (not used) in gemm1/gemm3 + F_rm2 : int # number of warps along k seqlen (block warps) in gemm4 + F_rn2 : int # number of warps along q seqlen (block warps) in gemm4 + F_rk2 : int # number of warps along gemm-k (not used) in gemm4 + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ + f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_o{self.F_occupancy}" + +@dataclass +class FmhaBwdDQDKDVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_tile : FmhaBwdDQDKDVTileSize + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # + F_dbias : str # + F_dropout : str # + F_mask : str # value from MASK_MAP + F_mode : str # value from MODE_MAP + F_pipeline : str + mask_impl : str + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bk1 = self.F_tile.F_bk1, + F_bk2 = self.F_tile.F_bk2, + F_bk3 = self.F_tile.F_bk3, + F_bk4 = self.F_tile.F_bk4, + F_bhdq = self.F_tile.F_bhdq, + F_bhdv = self.F_tile.F_bhdv, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_rm2 = self.F_tile.F_rm2, + F_rn2 = self.F_tile.F_rn2, + F_rk2 = self.F_tile.F_rk2, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_spad = BOOL_MAP[self.F_spad], + F_skpad = BOOL_MAP[self.F_skpad], + F_dpad = BOOL_MAP[self.F_dpad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_bias = BIAS_MAP[self.F_bias], + F_dbias = BOOL_MAP[self.F_dbias], + F_dropout = BOOL_MAP[self.F_dropout], + F_occupancy = self.F_tile.F_occupancy, + F_mask = get_mask_map(self.mask_impl)[self.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], + F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_dbias == 't' : n += '_dbias' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_dropout == 't' : n += '_dropout' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaBwdDQDKDVApiTrait: + return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bhdq=self.F_tile.F_bhdq, + bhdv=self.F_tile.F_bhdv, + mask=self.F_mask, + bias=self.F_bias, + dbias=self.F_dbias, + dropout=self.F_dropout, + spad=self.F_spad, + skpad=self.F_skpad, + dpad=self.F_dpad, + dvpad=self.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size & pipeline. +def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : [FmhaBwdDQDKDVTileSize(128, 128, 32, 32, 32, 32, 32, 32, 32, 1, 4, 1, 4, 1, 1, 4, 1, 1, 32, 32, 16, 1), + "qs_ks_vr_dos"], + '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "qs_ks_vr_dos"], + '128' : [FmhaBwdDQDKDVTileSize( 64, 128, 32, 32, 32, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 1), + "ks_vr"] + } + else: + return None + +def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: + # TODO: we don't support tuning yet, so pick up one value for pad + # support this in future + gen = list() + api_pool = FmhaBwdApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): + tile = d[hdim_str][0] + ppl = d[hdim_str][1] + hdim = int(hdim_str) + if (mode == "group") and (spad == "f" or skpad == "f"): + continue + if ((bias == "no" or bias == "alibi") and dbias == "t"): + continue + k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, + F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, + F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, + F_pipeline=ppl, mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + if not cond: + continue + api_pool.register_dq_dk_dv_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_bwd_dot_do_o_trait_{F_idx} = ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, + {F_dvpad}, + {F_occupancy}>; + +using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< + typename FmhaBwdTypeConfig::ODataType, + typename FmhaBwdTypeConfig::OGradDataType, + typename FmhaBwdTypeConfig::DDataType, + /* BlockSize = */ 256, + {F_hdim}, + {F_mode}, + fmha_bwd_dot_do_o_trait_{F_idx}>; + +using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO< + fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; + +using fmha_bwd_dot_do_o_kernel_{F_idx} = + ck_tile::FmhaBwdOGradDotOKernel, + fmha_bwd_dot_do_o_{F_idx}>; + +using dot_do_o_trait_{F_idx} = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; + +#include + +template<> +float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +std::string fmha_bwd_dot_do_o_get_name_() +{{ + using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; + return k_::GetName(); +}} +""" + +@dataclass +class FmhaBwdOGradDotOKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_spad : str # true/false + F_dvpad : str # + F_mode : str # value from MODE_MAP + F_occupancy : int + + @property + def template(self) -> str: + return FMHA_BWD_KERNEL_HEADER + \ + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_spad = BOOL_MAP[self.F_spad], + F_dvpad = BOOL_MAP[self.F_dvpad], + F_mode = MODE_MAP[self.F_mode], + F_occupancy = self.F_occupancy) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" + if pn != '' : n += f'_{pn}' + return n + + @property + def filename(self) -> str: + return self.name + ".cpp" + +def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + gen = list() + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) + if d == None: + continue + for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): + hdim = int(hdim_str) + if (mode == "group" and spad == "f"): + continue + k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, + F_spad=spad, F_dvpad=dvpad, F_mode=mode, + F_occupancy=get_occupancy(dtype, hdim)) + gen.append(k) + + return gen + +def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + write_single_bwd_dot_do_o_kernel(kernel, output_dir) + api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) + write_bwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + kernels = get_bwd_dot_do_o_blobs() + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..137d3a2f70a6939bf3a4195a6d8d75cb070d4105 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -0,0 +1,501 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +TILE_PARTITIONER_MAP = { + "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", + "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdKernel<{F_tile_partitioner}, + fmha_pipeline_{F_idx}, + fmha_epilogue_{F_idx}>; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_fwd_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0blen : int + vlayout : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {self.bk0blen} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {self.bk0blen} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' + if self.F_squant == 't' : n += '_squant' + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm : int # number of warps along q seqlen (block warps) + F_rn : int # number of warps along k seqlen(not used) + F_rk : int # number of warps along gemm-k(not used) + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ + f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + def get_tp(self) -> str: + if self.F_mode == 'group': + return 'hbs' + else: + return 'shb' + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + if hdim == 256: + # if True: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + if receipt == 1: + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse/dropout kernels + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py new file mode 100644 index 0000000000000000000000000000000000000000..50939450951d5d8952bd251edc4d99dd331b9f13 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -0,0 +1,674 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple, Union + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + +from codegen.ops.fmha_fwd import ( + FmhaFwdTileSize, + FmhaFwdApiTrait, + FMHA_FWD_KERNEL_HEADER, + FMHA_FWD_API_PER_DTYPE, + FMHA_FWD_API_PER_HDIM_CASE, +) + + +FMHA_FWD_SPLITKV_PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", +} + +FMHA_FWD_SPLITKV_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_mask_{F_idx} = {F_mask}; + +namespace {{ +template +struct kernel_runner {{ +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_warps = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; +using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape = ck_tile::TileFmhaShape; + +using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + kHasUnevenSplits, + {F_occupancy}>; + +using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::OaccDataType, + fmha_shape, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait>; + +using fmha_pipeline = {F_pipeline}< + fmha_pipeline_problem>; + +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel = + ck_tile::FmhaFwdSplitKVKernel, + fmha_pipeline, + fmha_epilogue>; + +static void run(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} +}}; +}} + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + if constexpr({F_mode} == false) {{ // batch mode + if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ + kernel_runner::run(s, a); + }} else {{ + kernel_runner::run(s, a); + }} + }} else {{ + kernel_runner::run(s, a); + }} +}} + +template<> +std::string fmha_fwd_splitkv_get_name_() +{{ + using k_ = kernel_runner::fmha_kernel; /// FIXME: choose real kernel type + return k_::GetName(); +}} +""" + +FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +namespace {{ +template +struct kernel_runner {{ +using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, + {F_dvpad}, + {F_lse}, + {F_squant}, + kLogMaxSplits, + {F_occupancy}>; + +using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + {F_hdim}, + {F_bm0}, + {F_bn1}, + {F_mode}, + fmha_trait>; + +using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< + fmha_pipeline_problem>; + +using fmha_epilogue = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel = + ck_tile::FmhaFwdSplitKVCombineKernel, + fmha_pipeline, + fmha_epilogue>; + +static void run(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel; + auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); +}} +}}; +}} + +using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1}, + {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; + +#include + +template<> +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + if (a.num_splits <= 16) {{ + kernel_runner<4>::run(s, a); + }} else if (a.num_splits <= 32) {{ + kernel_runner<5>::run(s, a); + }} else if (a.num_splits <= 64) {{ + kernel_runner<6>::run(s, a); + }} else if (a.num_splits <= 128) {{ + kernel_runner<7>::run(s, a); + }} +}} + +template<> +std::string fmha_fwd_splitkv_combine_get_name_() +{{ + using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type + return k_::GetName(); +}} +""" + +FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp" +FMHA_FWD_SPLITKV_API=""" +#include + +template +float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + if(s.log_level_ > 0) + std::cout + << ", " << fmha_fwd_splitkv_get_name_() + << ", " << fmha_fwd_splitkv_combine_get_name_() + << std::flush; + + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} + ); +}} + +float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} +""" + +@dataclass +class FmhaFwdSplitKVPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' + if self.F_squant == 't' : n += '_squant' + return n + +@dataclass +class FmhaFwdSplitKVCombinePipeline: + tag : str + + F_spad : str # true/false + F_dvpad : str # + F_lse : str # + F_squant : str # + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}' + if pn != '' : n += f'_{pn}' + if self.F_lse == 't' : n += '_lse' + if self.F_squant == 't' : n += '_squant' + return n + +class FmhaFwdSplitKVApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdSplitKVCombineTileSize: + F_bm0 : int # tile size along q seqlen + F_bn1 : int # tile size along v head_dim + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdSplitKVKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdSplitKVPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_SPLITKV_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +@dataclass +class FmhaFwdSplitKVCombineKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdSplitKVCombineTileSize + F_pipeline : FmhaFwdSplitKVCombinePipeline + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn1 = self.F_tile.F_bn1, + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_mode = MODE_MAP[self.F_mode]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) + } + else: + return None + +def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), + '64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), + '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), + '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), + '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), + '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), + } + else: + return None + +def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: + Pipeline = FmhaFwdSplitKVPipeline + Kernel = FmhaFwdSplitKVKernel + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + # splitkv kernel donot support dropout + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]): + if hdim == 256: + # if True: + pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + else: + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + if receipt == 1: + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse/dropout kernels + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdSplitKVApiPool(mask_impl) + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = Kernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]: + Pipeline = FmhaFwdSplitKVCombinePipeline + Kernel = FmhaFwdSplitKVCombineKernel + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]): + pipelines.append(Pipeline('unused', spad, dvpad, lse, squant)) + elif dtype in ['fp8', 'bf8']: + # no need lse kernels + pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant)) + else: + assert False + return pipelines + + gen = list() + + for dtype in DTYPE_MAP.keys(): + d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = Kernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + gen.append(k) + + return gen + +def write_single_kernel(kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None: + file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME + file_path.write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + api_pool, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_fwd_splitkv_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: + with file_path.open('a') as f: + kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + _, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") \ No newline at end of file diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b1249b5eda93c048cc231448484edc9bd31e6fd9 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -0,0 +1,932 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_bwd.hpp" +#include "ck_tile/host.hpp" +#include "mask.hpp" +#include "utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") + .insert("dbias", "0", "output bias gradient or not") + .insert("prec", "fp16", "data type. fp16 or bf16") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)") + .insert("kname", "0", "if set to 1 will print kernel name") + .insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for random number generator") + .insert("drop_offset", "0", "offset for random number generator") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(int /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k < 0) + nhead_k = nhead; + + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return false; + } + + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + if(seqlen_k < 0) + seqlen_k = seqlen_q; + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v < 0) + hdim_v = hdim_q; + if(hdim_q % 2 != 0 || hdim_v % 2 != 0) + { + std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl; + return false; + } + + bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim + bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale = arg_parser.get_float("scale"); + if(scale == .0f) + scale = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + + bias_info bias = bias_info::decode(arg_parser.get_str("bias")); + bool use_dbias = arg_parser.get_bool("dbias"); + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + if(use_dbias && bias.type != bias_enum::elementwise_bias) + { + std::cerr << "dbias only exists when bias type is elementwise" << std::endl; + return false; + } + + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return false; + } + float p_undrop = 1.0 - p_drop; + uint8_t p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; + + bool s_randval = false; + if(p_drop > 0.0f && do_validation) + { + s_randval = true; + } + + mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + + int init_method = arg_parser.get_int("init"); + std::optional seed = arg_parser.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + stream_warmup, + stream_repeat, + arg_parser.get_str("timer") == std::string("gpu")}; + + const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); + const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + + using TypeConfig = FmhaBwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using GemmDataType = typename TypeConfig::GemmDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using AccDataType = typename TypeConfig::AccDataType; + using DDataType = typename TypeConfig::DDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using ODataType = typename TypeConfig::ODataType; + using OGradDataType = typename TypeConfig::OGradDataType; + using QGradDataType = typename TypeConfig::QGradDataType; + using KGradDataType = typename TypeConfig::KGradDataType; + using VGradDataType = typename TypeConfig::VGradDataType; + using BiasGradDataType = typename TypeConfig::BiasGradDataType; + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + { + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + + flop += nhead * (static_cast(3) * static_cast(2) * + real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T + static_cast(2) * static_cast(2) * + real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * real_seqlen_k * hdim_v + + sizeof(ODataType) * real_seqlen_q * hdim_v + + sizeof(OGradDataType) * real_seqlen_q * hdim_v + + sizeof(QGradDataType) * real_seqlen_q * hdim_q + + sizeof(KGradDataType) * real_seqlen_k * hdim_q + + sizeof(VGradDataType) * real_seqlen_k * hdim_v + + sizeof(LSEDataType) * real_seqlen_q); + } + } + + auto get_lengths = [&](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + // host memory for storing all the tensor elements + const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + const ck_tile::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + + ck_tile::HostTensor q_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor k_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor v_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)); + ck_tile::HostTensor bias_host( + bias.type == bias_enum::elementwise_bias + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor alibi_slope_host( + bias.type == bias_enum::alibi + ? (bias.rank_info == 0 ? std::array{1, nhead} + : std::array{batch, nhead}) + : std::array{1, 1}); + ck_tile::HostTensor o_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor lse_host( + std::array{batch, nhead, max_seqlen_q}); + ck_tile::HostTensor d_host( + std::array{batch, nhead, max_seqlen_q}); + ck_tile::HostTensor randval_host( + p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + ck_tile::HostTensor dq_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor dk_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor dv_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v)); + ck_tile::HostTensor do_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + ck_tile::HostTensor dbias_host( + use_dbias + ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + + if(init_method == 0) + { + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(do_host); + } + else if(init_method == 1) + { + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(do_host); + } + else if(init_method == 2) + { + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(bias_host); + ck_tile::FillTrigValue{}(do_host); + } + if(bias.type == bias_enum::alibi) + { + auto slopes = ck_tile::get_alibi_slopes(nhead); + assert(slopes.size() == nhead); + if(bias.rank_info == 0) + { + // alibi in 1*h + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); + } + else + { + // alibi in b*h + for(auto i_b = 0; i_b < batch; i_b++) + { + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); + } + } + } + + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + bias_buf.ToDevice(bias_host.data()); + do_buf.ToDevice(do_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + alibi_slope_buf.ToDevice(alibi_slope_host.data()); + + // clang-format off + auto layout_str = [&](bool permute){ + if (permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if (iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + const std::string prec = arg_parser.get_str("prec"); + + std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias + << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask + << std::flush; + + auto fmha_traits = fmha_bwd_traits{hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + mask.type, + bias.type, + use_dbias, + p_drop > 0.0f}; + auto fmha_args = [&]() { + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & + /// 'nhead_stride_bias' are 0. + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v); + const ck_tile::index_t stride_bias = (max_seqlen_k); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_bias = 0; + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_lsed = max_seqlen_q; + const ck_tile::index_t nhead_stride_dbias = + (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v); + const ck_tile::index_t batch_stride_bias = 0; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); + const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); + + return fmha_bwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + do_buf.GetDeviceBuffer(), + d_buf.GetDeviceBuffer(), + randval_buf.GetDeviceBuffer(), + dq_buf.GetDeviceBuffer(), + dk_buf.GetDeviceBuffer(), + dv_buf.GetDeviceBuffer(), + dbias_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + max_seqlen_k, + hdim_q, + hdim_v, + nhead, + nhead_k, + scale, + stride_q, + stride_k, + stride_v, + bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) + : stride_bias, + stride_o, + stride_randval, + stride_do, + stride_dk, + stride_dv, + stride_dbias, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_o, + nhead_stride_randval, + nhead_stride_do, + nhead_stride_lsed, + nhead_stride_dbias, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_o, + batch_stride_randval, + batch_stride_do, + batch_stride_lsed, + batch_stride_dk, + batch_stride_dv, + batch_stride_dbias, + mask.left, + mask.right, + static_cast(mask.type), + p_drop, + p_undrop, + s_randval, + {drop_seed, drop_offset}}; + }(); + + float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); + if(ave_time < 0) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return false; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush; + + if(!do_validation) + { + std::cout << std::flush << std::endl; + return true; + } + + bool pass = true; + + std::vector> q_host_refs; + std::vector> k_host_refs; + std::vector> v_host_refs; + std::vector> o_host_refs; + std::vector> randval_host_refs; + std::vector> p_hp_host_refs; + std::vector> p_lp_host_refs; + + randval_buf.FromDevice(randval_host.data()); + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); // lse_g_m + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n + ck_tile::HostTensor s_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n + ck_tile::HostTensor p_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision + ck_tile::HostTensor p_dropped_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision + ck_tile::HostTensor p_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + // clang-format on + + // reference + // S = scale * Q * K^T + ck_tile::reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k + + if(bias.type == bias_enum::elementwise_bias) + { + // elementwise bias + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + ck_tile:: + reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) + { + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + AccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope + : -current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + AccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } + } + } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile:: + reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + ck_tile::reference_batched_softmax( + s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref); + + if(p_drop > 0) + { + p_hp_host_ref.ForEach( + [&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); }); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) { + p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + else + { + p_hp_host_ref.ForEach([&](auto& self, auto idx) { + p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + + // O = P * V + ck_tile::reference_batched_gemm( + p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n + + // clang-format off + // permute + if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); + else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); + + lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); }); + // clang-format on + + q_host_refs.push_back(q_host_ref); + k_host_refs.push_back(k_host_ref); + v_host_refs.push_back(v_host_ref); + o_host_refs.push_back(o_host_ref); + p_hp_host_refs.push_back(p_hp_host_ref); + p_lp_host_refs.push_back(p_lp_host_ref); + if(p_drop > 0) + { + randval_host_refs.push_back(randval_host_ref); + } + } + + o_buf.ToDevice(o_host.data()); + lse_buf.ToDevice(lse_host.data()); + dq_buf.SetZero(); + dbias_buf.SetZero(); + + ck_tile::stream_config stream_config_v{ + nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; + fmha_bwd(fmha_traits, fmha_args, stream_config_v); + + dq_buf.FromDevice(dq_host.data()); + dk_buf.FromDevice(dk_host.data()); + dv_buf.FromDevice(dv_host.data()); + dbias_buf.FromDevice(dbias_host.data()); + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + + ck_tile::HostTensor do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o + ck_tile::HostTensor ds_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision + ck_tile::HostTensor ds_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision + ck_tile::HostTensor dp_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision + ck_tile::HostTensor dbias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + ck_tile::HostTensor dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + ck_tile::HostTensor dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + + // clang-format off + if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); }); + else do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, i[0], i[2]); }); + // clang-format on + + // dP = dO@V x Z w/ dropout + // dP = dO@V w/o dropout + auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o + ck_tile::reference_batched_gemm( + do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o + + if(p_drop > 0) + { + ck_tile::reference_batched_dropout( + dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop); + } + + // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) + ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + auto idx_gmo = idx_gmn; + idx_gmo[2] = o; + do_dot_o += ck_tile::type_convert(do_host_ref(idx_gmo)) * + ck_tile::type_convert(o_host_refs[wb](idx_gmo)); + } + self(idx_gmn) = ck_tile::type_convert( + p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o)); + }); + + if(use_dbias) + { + ds_hp_host_ref.ForEach([&](auto& self, auto idx) { + dbias_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + } + + ds_hp_host_ref.ForEach([&](auto& self, auto idx) { + ds_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); + }); + + // dV = P_drop^T@dO^T + // dV = P^T@dO^T w/o dropout + auto p_t_lp_host_ref = p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m + auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m + ck_tile::reference_batched_gemm( + p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m + + // dQ = scale * dS@K^T + auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n + ck_tile::reference_batched_gemm( + ds_lp_host_ref, + k_t_host_ref, + dq_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n + + // dK = scale * dS^T@Q^T + auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m + auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m + ck_tile::reference_batched_gemm( + ds_t_lp_host_ref, + q_t_host_ref, + dk_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m + + ck_tile::HostTensor dq_host_result( + {nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + ck_tile::HostTensor dk_host_result( + {nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_result( + {nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + ck_tile::HostTensor dbias_host_result( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + + // clang-format off + // permute + if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + + if(i_perm) dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[0], idx[1] + key_offset, idx[2]); }); + else dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + + if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); }); + else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); }); + + if(use_dbias) + { + if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + } + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool dq_cur_pass = ck_tile::check_err(dq_host_result, + dq_host_ref, + std::string("Error: QGrad Incorrect results!"), + rtol, + atol); + bool dk_cur_pass = ck_tile::check_err(dk_host_result, + dk_host_ref, + std::string("Error: KGrad Incorrect results!"), + rtol, + atol); + bool dv_cur_pass = ck_tile::check_err(dv_host_result, + dv_host_ref, + std::string("Error: VGrad Incorrect results!"), + rtol, + atol); + + bool dbias_cur_pass = true; + if(use_dbias) + { + dbias_cur_pass = ck_tile::check_err(dbias_host_result, + dbias_host_ref, + std::string("Error: BiasGrad Incorrect results!"), + rtol, + atol); + } + pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass); + if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass)) + { + std::cerr << "mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0c6b468951881d4838d07f832a3c27877768cc70 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -0,0 +1,359 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "mask.hpp" +#include "bias.hpp" +#include + +template +struct FmhaBwdTypeConfig; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using GemmDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::half_t; + using OGradDataType = ck_tile::half_t; + using QGradDataType = ck_tile::half_t; + using KGradDataType = ck_tile::half_t; + using VGradDataType = ck_tile::half_t; + using BiasGradDataType = ck_tile::half_t; +}; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using GemmDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::bf16_t; + using OGradDataType = ck_tile::bf16_t; + using QGradDataType = ck_tile::bf16_t; + using KGradDataType = ck_tile::bf16_t; + using VGradDataType = ck_tile::bf16_t; + using BiasGradDataType = ck_tile::bf16_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_bwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + const void* o_ptr; + const void* lse_ptr; + const void* do_ptr; + void* d_ptr; + void* rand_val_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + void* dbias_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t max_seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + float scale; + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + ck_tile::index_t stride_dbias; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + ck_tile::index_t nhead_stride_dbias; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_lsed; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + ck_tile::index_t batch_stride_dbias; + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + float p_drop; + float p_undrop; + bool s_randval; + std::tuple drop_seed_offset; +}; + +template +auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) + { + return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dq_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dbias, + args.batch_stride_lsed, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dq_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dbias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_do, + args.batch_stride_lsed, + args.batch_stride_dk, + args.batch_stride_dv, + args.batch_stride_dbias, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode) + { + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqstart_q_ptr, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_lsed); + } + else + { // create batch mode kernel arguments + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqlen_q, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_do, + args.batch_stride_o, + args.batch_stride_lsed); + } + }(); + + dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_bwd_dq_dk_dv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dq_dk_dv_get_name_(); + +template +struct fmha_bwd_dot_do_o_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dot_do_o_get_name_(); + +// This is the public API, will be generated by script +struct fmha_bwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_dbias; + bool has_dropout; + // TODO: padding check is inside this api +}; +float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 8ca4ff9337728a4b7b31135d80c6c77455fe097e..28f790573415744b8ee347aa586f78f60e8ac784 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_fwd.hpp" #include "ck_tile/host.hpp" @@ -41,16 +41,23 @@ auto create_args(int argc, char* argv[]) .insert("b", "2", "batch size") .insert("h", "8", "num of head, for q") .insert("h_k", - "0", - "num of head, for k/v, 0 means equal to h\n" + "-1", + "num of head, for k/v, -1 means equal to h\n" "if not equal to h, then this is GQA/MQA case") - .insert("s", - "3328", - "seqlen_q. if group-mode, means the average value of seqlen_q\n" - "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") - .insert("s_k", "0", "seqlen_k, 0 means equal to s") + .insert( + "s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" + "also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("s_kpad", + "-1", + "seqlen_k stride between 2 tokens, currently used in group-mode only\n" + "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" + "along seqlen, instead of packed. same as xformer kv_padding") .insert("d", "128", "head dim for q, k") - .insert("d_v", "0", "head dim for v, 0 means equal to d") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" @@ -60,18 +67,24 @@ auto create_args(int argc, char* argv[]) .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") - .insert( - "squant", - "0", - "if using static quantization fusion or not. 0: original flow(not prefered)\n" - "1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,\n" - "scale_o according to range_q, range_k, range_v, range_p, range_o") + .insert("squant", + "auto", + "if using static quantization fusion or not. auto: fp8 will default use squant, " + "other will not\n" + "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " + "P and O.\n" + "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " + "range_p, range_o") .insert("iperm", "1", "permute input\n" "if true, will be b*h*s*d, else b*s*h*d") .insert("operm", "1", "permute output") - .insert("bias", "0", "add bias or not") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("mask", "0", @@ -88,12 +101,22 @@ auto create_args(int argc, char* argv[]) .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("lse", "0", "0 not store lse, 1 store lse") .insert("kname", "0", "if set to 1 will print kernel name") - .insert( - "init", "1", "init method. 0:random int, 1:random float, 2:trig float, 3:quantization") + .insert("init", + "uf", + "init method. ui, uniform random int, ni, normalized random int\n" + "uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, " + "quantization") .insert("seed", "11939", "random seed used for initializing input tensors. 0 for " "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for random number generator") + .insert("drop_offset", "0", "offset for random number generator") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("num_splits", + "1", + "# of splits for key/value. 0 to determine actual number by heuristic") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -103,7 +126,7 @@ auto create_args(int argc, char* argv[]) // different threshold for different dtype template -auto get_elimit(int /*init_method*/) +auto get_elimit(std::string /*init_method*/) { double rtol = 1e-3; double atol = 1e-3; @@ -111,26 +134,17 @@ auto get_elimit(int /*init_method*/) } template <> -auto get_elimit(int init_method) +auto get_elimit(std::string /*init_method*/) { - if(init_method == 0) - { - double rtol = 1e-2; - double atol = 1e-2; - return ck_tile::make_tuple(rtol, atol); - } - else - { - double rtol = 3e-3; - double atol = 3e-3; - return ck_tile::make_tuple(rtol, atol); - } + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); } template <> -auto get_elimit(int init_method) +auto get_elimit(std::string init_method) { - if(init_method == 0) + if(init_method == "ui" || init_method == "ni") { unsigned max_rounding_point_distance = 0; double atol = 2e-3; @@ -144,6 +158,106 @@ auto get_elimit(int init_method) } } +int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) +{ + // If we have enough to almost fill the SMs, then just use 1 split + if(batch_nhead_mblocks >= 0.8f * num_SMs) + { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || + ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { + efficiency.push_back(0.f); + } + else + { + float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if(eff > max_efficiency) + { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { + continue; + } + if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) + { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +int override_num_splits_if_necessary( + int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) +{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return num_splits; + } + + hipDeviceProp_t props{}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return num_splits; + } + + // tile size should match the generate.py + const int kM0 = 64; + const int kN1 = hdim_v; + + const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); + const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); + + if(num_splits < 1 && p_drop == 0.0f) + { + return num_splits_heuristic( + batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + } + + return num_splits; +} + +float fmha_fwd_dispatch(fmha_fwd_traits traits, + fmha_fwd_args args, + const ck_tile::stream_config& config) +{ + if(1 < args.num_splits) + { + return fmha_fwd_splitkv(traits, args, config); + } + else + { + return fmha_fwd(traits, args, config); + } +} + template bool run(const ck_tile::ArgParser& arg_parser) { @@ -153,7 +267,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t batch = arg_parser.get_int("b"); ck_tile::index_t nhead = arg_parser.get_int("h"); ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - if(nhead_k == 0) + if(nhead_k < 0) nhead_k = nhead; if(nhead % nhead_k != 0) @@ -162,13 +276,23 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - if(seqlen_k == 0) - seqlen_k = seqlen_q; + auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode, + batch, + arg_parser.get_str("s"), + arg_parser.get_str("s_k"), + arg_parser.get_str("s_kpad")); + +#if 0 + // clang-format off + std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_kpads:"; for(auto xx : seqlen_kpads) { std::cout << xx << ","; } std::cout << std::endl; + // clang-format on +#endif + ck_tile::index_t hdim_q = arg_parser.get_int("d"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - if(hdim_v == 0) + if(hdim_v < 0) hdim_v = hdim_q; bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim @@ -178,15 +302,18 @@ bool run(const ck_tile::ArgParser& arg_parser) if(scale_s == .0f) scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? - bool squant = arg_parser.get_bool("squant"); - if constexpr(!std::is_same_v) - { - if(squant) + std::string squant_str = arg_parser.get_str("squant"); + bool squant = [&]() { + if(squant_str == "auto") { - std::cerr << "static quantization only support fp8 for now" << std::endl; - return false; + if(data_type == "fp8") + return true; + else + return false; } - } + else + return atoi(squant_str.c_str()) != 0 ? true : false; + }(); float range_q = arg_parser.get_float("range_q"); float range_k = arg_parser.get_float("range_k"); @@ -208,45 +335,70 @@ bool run(const ck_tile::ArgParser& arg_parser) } std::string vlayout = arg_parser.get_str("vlayout"); - bool use_bias = arg_parser.get_bool("bias"); bool lse = arg_parser.get_bool("lse"); - mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + bias_info bias = bias_info::decode(arg_parser.get_str("bias")); + mask_info mask = mask_info::decode( + arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore - int init_method = arg_parser.get_int("init"); + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return false; + } + + bool s_randval = false; + if(p_drop > 0.0f && do_validation) + { + s_randval = true; + } + + std::string init_method = arg_parser.get_str("init"); std::optional seed = arg_parser.get_uint32("seed"); if(*seed == 0) { seed.reset(); } + int num_splits = arg_parser.get_int("num_splits"); + int stream_warmup = arg_parser.get_int("warmup"); int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); - ck_tile::stream_config stream_config{ - nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat}; + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + stream_warmup, + stream_repeat, + arg_parser.get_str("timer") == std::string("gpu")}; - const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); - const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + const auto seqstart_q_host = to_seqstarts(seqlen_qs); + const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); using TypeConfig = FmhaFwdTypeConfig; - using QDataType = typename TypeConfig::QDataType; - using KDataType = typename TypeConfig::KDataType; - using VDataType = typename TypeConfig::VDataType; - using BiasDataType = typename TypeConfig::BiasDataType; - using LSEDataType = typename TypeConfig::LSEDataType; - using SaccDataType = typename TypeConfig::SaccDataType; - using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; - using PDataType = typename TypeConfig::PDataType; - using OaccDataType = typename TypeConfig::OaccDataType; - using ODataType = typename TypeConfig::ODataType; + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -258,6 +410,11 @@ bool run(const ck_tile::ArgParser& arg_parser) max_seqlen_q = real_seqlen_q; } + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + flop += nhead * (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); @@ -268,6 +425,18 @@ bool run(const ck_tile::ArgParser& arg_parser) } } + // legalize num_splits according to other options + if(num_splits < 1) + { + num_splits = override_num_splits_if_necessary( + batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); + } + if(128 < num_splits) + { + std::cerr << "num_splits greater than 128 is not supported" << std::endl; + return false; + } + auto get_lengths = [&](bool permute, ck_tile::index_t b /*batch*/, ck_tile::index_t h /*nhead*/, @@ -284,9 +453,11 @@ bool run(const ck_tile::ArgParser& arg_parser) // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); const ck_tile::index_t shape_seqlen_q = - (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = - (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + (mode == mode_enum::batch ? seqlen_ks[0] + : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() + : seqstart_k_with_padding_host.back())); ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); @@ -295,42 +466,75 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor v_host( is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); - // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host - // will not be used for verification at all (but will be copied to device anyway). + ck_tile::HostTensor bias_host( - use_bias + bias.type == bias_enum::elementwise_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); - // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] + + ck_tile::HostTensor alibi_slope_host( + bias.type == bias_enum::alibi + ? (bias.rank_info == 0 ? std::array{1, nhead} + : std::array{batch, nhead}) + : std::array{1, 1}); + + ck_tile::HostTensor lse_acc_host( + 1 < num_splits ? std::array{num_splits, batch, nhead, max_seqlen_q} + : std::array{1, 1, 1, 1}); + ck_tile::HostTensor o_acc_host( + 1 < num_splits + ? std::array{num_splits, batch, nhead, max_seqlen_q, hdim_v} + : std::array{1, 1, 1, 1, 1}); + + // self define lse data layout as [batch, nhead, max_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q} + lse ? std::array{batch, nhead, max_seqlen_q} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); - if(init_method == 0) + ck_tile::HostTensor randval_host( + p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + + if(init_method == "ui" || init_method == "0") { - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); - ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); } - else if(init_method == 1) + else if(init_method == "ni") + { + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + } + else if(init_method == "uf" || init_method == "1") { ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); } - else if(init_method == 2) + else if(init_method == "nf") + { + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(bias_host); + } + else if(init_method == "tf" || init_method == "2") { ck_tile::FillTrigValue{}(q_host); ck_tile::FillTrigValue{}(k_host); ck_tile::FillTrigValue{}(v_host); ck_tile::FillTrigValue{}(bias_host); } - else if(init_method == 3) // suitable for fp8 quantization + else if(init_method == "ufq" || init_method == "uf:q" || + init_method == "3") // suitable for fp8 quantization { ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(q_host); ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(k_host); @@ -341,22 +545,48 @@ bool run(const ck_tile::ArgParser& arg_parser) // Assume bias is in [-1.f, 1.f] in original fp32 ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); } + if(bias.type == bias_enum::alibi) + { + auto slopes = ck_tile::get_alibi_slopes(nhead); + assert(slopes.size() == nhead); + if(bias.rank_info == 0) + { + // alibi in 1*h + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); + } + else + { + // alibi in b*h + for(auto i_b = 0; i_b < batch; i_b++) + { + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); + } + } + } ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t)); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); v_buf.ToDevice(v_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqstart_k_host.data()); + seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() + : seqstart_k_with_padding_host.data()); + seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data()); + alibi_slope_buf.ToDevice(alibi_slope_host.data()); // clang-format off auto layout_str = [&](bool permute){ @@ -371,10 +601,17 @@ bool run(const ck_tile::ArgParser& arg_parser) const std::string prec = arg_parser.get_str("prec"); std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch - << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k - << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s - << ", bias:" << use_bias << ", lse:" << lse << ", squant:" << squant - << ", mask:" << mask << ", v:" << vlayout << std::flush; + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0] + << (seqlen_kpads[0] < 0 ? "" + : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) + << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias + << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant + << ", mask:" << mask << ", v:" << vlayout; + if(1 < num_splits) + { + std::cout << ", num_splits:" << num_splits; + } + std::cout << std::flush; auto fmha_traits = fmha_fwd_traits{hdim_q, hdim_v, @@ -382,8 +619,9 @@ bool run(const ck_tile::ArgParser& arg_parser) mode == mode_enum::group, is_v_rowmajor, mask.type, - use_bias, + bias.type, lse, + p_drop > 0.0f, squant}; auto p_compute_element_func = [&]() { @@ -401,7 +639,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return ck_tile::identity{}; }(); - auto fmha_args = [&]() { + auto fmha_args = [&, k_paddings_ = seqlen_kpads]() { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// seqlen_k] in this example, hence both the 'batch_stride_bias' & @@ -415,8 +653,10 @@ bool run(const ck_tile::ArgParser& arg_parser) else return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; }(); - const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_o_acc = hdim_v; + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); @@ -428,25 +668,38 @@ bool run(const ck_tile::ArgParser& arg_parser) }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); - const ck_tile::index_t nhead_stride_lse = (shape_seqlen_q * 1); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = max_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q; + const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); - const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); - const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1); - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q); + const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + // setup split_stride_* arguments (only used in split-kv kernel) + const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q); + const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); return fmha_fwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), - bias_buf.GetDeviceBuffer(), + bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(), + randval_buf.GetDeviceBuffer(), + lse_acc_buf.GetDeviceBuffer(), + o_acc_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(), - nullptr, + k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(), shape_seqlen_q, shape_seqlen_k, batch, @@ -455,32 +708,47 @@ bool run(const ck_tile::ArgParser& arg_parser) hdim_v, nhead, nhead_k, + num_splits, scale_s, scale_p, scale_o, stride_q, stride_k, stride_v, - stride_bias, + bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) + : stride_bias, + stride_randval, + stride_o_acc, stride_o, nhead_stride_q, nhead_stride_k, nhead_stride_v, nhead_stride_bias, + nhead_stride_randval, nhead_stride_lse, + nhead_stride_lse_acc, + nhead_stride_o_acc, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_bias, + batch_stride_randval, batch_stride_lse, + batch_stride_lse_acc, + batch_stride_o_acc, batch_stride_o, + split_stride_lse_acc, + split_stride_o_acc, mask.left, mask.right, - static_cast(mask.type)}; + static_cast(mask.type), + p_drop, + s_randval, + {drop_seed, drop_offset}}; }(); - float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); + float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config); if(ave_time < 0) { @@ -504,6 +772,11 @@ bool run(const ck_tile::ArgParser& arg_parser) o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); + randval_buf.FromDevice(randval_host.data()); + float p_undrop = 1.0 - p_drop; + uint8_t p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; bool pass = true; @@ -515,7 +788,10 @@ bool run(const ck_tile::ArgParser& arg_parser) // adjust matrix index according to the mode const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); - const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + const ck_tile::index_t key_offset = + (mode == mode_enum::batch + ? 0 + : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb])); const auto v_host_ref_lengths = std::array{nhead, hdim_v, real_seqlen_k}; @@ -564,8 +840,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::identity{}, ck_tile::scales(scale_s)); - if(use_bias) + if(bias.type == bias_enum::elementwise_bias) { + // elementwise bias ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); // clang-format off if(i_perm) @@ -582,6 +859,52 @@ bool run(const ck_tile::ArgParser& arg_parser) SMPLComputeDataType>( s_host_ref, bias_host_ref, s_host_ref); } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) + { + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope + : -current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + SaccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } + } + } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); + } if(mask.type == mask_enum::no_mask) { @@ -629,6 +952,17 @@ bool run(const ck_tile::ArgParser& arg_parser) s_host_ref, p_host_ref, p_compute_element_func); } + if(p_drop > 0) + { + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + } + ck_tile::reference_batched_gemm( p_host_ref, v_host_ref, @@ -662,18 +996,17 @@ bool run(const ck_tile::ArgParser& arg_parser) if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b, idx[0], idx[1] + query_offset); - }); + lse_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); - bool lse_pass = ck_tile::check_err(lse_host_result, - lse_host_ref, - "LSE Error: Incorrect results!", - rtol, - atol, - /* allow_infinity_ref = */ true); + cur_pass = ck_tile::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); - pass &= lse_pass; + pass &= cur_pass; if(!cur_pass) { std::cerr << "LSE mismatch found at batch: " << wb << std::endl diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 9a82ab6b7942e4de8ad583eabcbcb0eb98382d12..ee932ce5d998488660da66d74f0d7aaf38637c7e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" #include "mask.hpp" +#include "bias.hpp" #include template @@ -16,61 +17,65 @@ struct FmhaFwdTypeConfig; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::half_t; - using KDataType = ck_tile::half_t; - using VDataType = ck_tile::half_t; - using BiasDataType = ck_tile::half_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::half_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::half_t; + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::bf16_t; - using KDataType = ck_tile::bf16_t; - using VDataType = ck_tile::bf16_t; - using BiasDataType = ck_tile::bf16_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf16_t; + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::fp8_t; - using KDataType = ck_tile::fp8_t; - using VDataType = ck_tile::fp8_t; - using BiasDataType = float; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::fp8_t; + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; }; template <> struct FmhaFwdTypeConfig { - using QDataType = ck_tile::bf8_t; - using KDataType = ck_tile::bf8_t; - using VDataType = ck_tile::bf8_t; - using BiasDataType = ck_tile::bf8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf8_t; + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; }; struct FmhaMasks @@ -86,7 +91,10 @@ struct fmha_fwd_args const void* q_ptr; const void* k_ptr; const void* v_ptr; - const void* bias_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_acc_ptr; + void* o_acc_ptr; void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -100,29 +108,43 @@ struct fmha_fwd_args ck_tile::index_t hdim_v; ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; + ck_tile::index_t num_splits; float scale_s; float scale_p; float scale_o; ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; - ck_tile::index_t stride_bias; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o_acc; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o; + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; + float p_drop; + bool s_randval; + std::tuple drop_seed_offset; }; template @@ -137,6 +159,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.rand_val_ptr, args.lse_ptr, args.o_ptr, args.seqstart_q_ptr, @@ -144,6 +167,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.seqlen_k_ptr, args.hdim_q, args.hdim_v, + args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, args.scale_p, @@ -152,16 +176,22 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.stride_k, args.stride_v, args.stride_bias, + args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, + args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.batch_stride_lse, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } else { // create batch mode kernel arguments @@ -169,12 +199,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.rand_val_ptr, args.lse_ptr, args.o_ptr, args.seqlen_q, args.seqlen_k, args.hdim_q, args.hdim_v, + args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, args.scale_p, @@ -183,22 +215,28 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.stride_k, args.stride_v, args.stride_bias, + args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, + args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, args.batch_stride_bias, + args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } }(); @@ -206,6 +244,176 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) return ck_tile::make_tuple(kargs, grids); } +template +auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.max_seqlen_q, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.max_seqlen_q, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = + Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel argumentszs + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.max_seqlen_q, + args.seqstart_q_ptr, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.batch_stride_lse, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.max_seqlen_q, + args.seqlen_q, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.batch_stride_lse, + args.batch_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + }(); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + // this is used to pattern-match internl kernel implementation, not to instantiate kernel template ; - static constexpr bool kHasBias = kHasBias_; + static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kPadS = kPadS_; static constexpr bool kPadSK = kPadSK_; @@ -252,6 +462,40 @@ struct fmha_fwd_traits_ template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); +template +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); + +template +std::string fmha_fwd_splitkv_get_name_(); + +template +struct fmha_fwd_splitkv_combine_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); + +template +std::string fmha_fwd_splitkv_combine_get_name_(); + // This is the public API, will be generated by script struct fmha_fwd_traits { @@ -261,9 +505,11 @@ struct fmha_fwd_traits bool is_group_mode; bool is_v_rowmajor; mask_enum mask_type; - bool has_bias; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; + bool has_dropout; bool do_fp8_static_quant; // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); +float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 56d699e5fef78c0bfe658bd098e9067faf58ca30..27347b4476661d0e9e644e9b49fa95a079fec75d 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -3,563 +3,62 @@ # generate kernel instances to speed up compilation import argparse -import itertools +from enum import IntEnum from pathlib import Path -from typing import List, Optional, Tuple -from dataclasses import dataclass -import copy -import fnmatch +from typing import List, Optional -DTYPE_MAP = { - "fp16": "ck_tile::fp16_t", - "bf16": "ck_tile::bf16_t", - "fp8" : "ck_tile::fp8_t" -} - -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} - -MASK_IMPL = { - "generic" : "ck_tile::GenericAttentionMask", - "simplified" : "ck_tile::SimplifiedGenericAttentionMask" -} - -MASK_SIMPLIFIED_MAP = { - "s_no" : "ck_tile::SimplifiedGenericAttentionMask", - "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", -} - -MASK_MAP = { - "no" : "FmhaMasks::NoMask", - "causal" : "FmhaMasks::CausalMask", - "generic" : "FmhaMasks::GenericMask" -} - -MODE_MAP = { - "batch" : "false", - "group" : "true" -} - -LAYOUT_MAP = { - "row" : "true", - "col" : "false" -} - -PIPELINE_MAP = { - "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", - "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", -} - -PIPELINE_ENUM_MAP = { - "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", -} - -BOOL_MAP = { - "t" : "true", - "f" : "false" -} - -DIRECTIONS = ["fwd"] -GEN_DIR = "" # in Cmake, have to generate files in same folder - -FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n -// auto generated by generate.py -#include "fmha_fwd.hpp" -""" - -FMHA_FWD_KERNEL_BODY=""" -using fmha_dtype_{F_idx} = {F_dtype}; - -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; -using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; -using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; - -using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; - -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, - {F_dpad}, - {F_dvpad}, - {F_bias}, - {F_lse}, - {F_squant}, - {F_occupancy}>; -using fmha_mask_{F_idx} = {F_mask}; - -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, - fmha_shape_{F_idx}, - {F_mode}, - fmha_mask_{F_idx}, - fmha_trait_{F_idx}>; - -using fmha_pipeline_{F_idx} = {F_pipeline}< - fmha_pipeline_problem_{F_idx}>; +from codegen.cmake_config import * +from codegen.ops import ( + fmha_fwd, + fmha_fwd_splitkv, + fmha_bwd +) -using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel, - fmha_pipeline_{F_idx}, - fmha_epilogue_{F_idx}>; +class HandlerId(IntEnum): + LIST_BLOBS = 0 + WRITE_BLOBS = 1 -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - -#include - -template<> -float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) -{{ - using k_ = fmha_kernel_{F_idx}; - if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); -}} -""" - -FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" -FMHA_FWD_API=""" -float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ - float r = -1; -{F_dispatch} - return r; -}} -""" - -FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ -{F_inner_dispatch} - }} -""" -MASK_CHECK_MAP = { - "no" : "t.mask_type == mask_enum::no_mask", - "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", - "generic" : "t.mask_type == mask_enum::window_generic", +handlers = { + 'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs), + 'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs), + 'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs), } -MASK_SIMPLIFIED_CHECK_MAP = { - "s_no" : "t.mask_type == mask_enum::no_mask", - "s_mask" : "t.mask_type != mask_enum::no_mask", -} - -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - return fmha_fwd_(s, a); - }} -""" - -def get_mask_map(mask : str): - if mask == "generic": - return MASK_MAP - elif mask == "simplified": - return MASK_SIMPLIFIED_MAP - else: - assert False - return None - -def get_mask_check_map(mask : str): - if mask == "generic": - return MASK_CHECK_MAP - elif mask == "simplified": - return MASK_SIMPLIFIED_CHECK_MAP - else: - assert False - return None - -@dataclass -class FmhaFwdApiTrait: - pipeline_tag : str - # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0blen : int - vlayout : str - mask : str - bias : str # true/false - lse : str # - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - - @property - def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' - - @property - def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False - - @property - def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' - else: assert False - - @property - def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr']: - if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {self.bk0blen} == 0' - else: assert False - - @property - def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr']: - if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {self.bk0blen} == 0' - else: assert False - -@dataclass -class FmhaFwdPipeline: - tag : str - - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_bias : str # true/false - F_lse : str # - F_squant : str # - F_mask : str # value from MASK_MAP - - @property - def name(self) -> str: - def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n - return n - pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - if self.F_bias == 't' : n += '_bias' - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - if self.F_lse == 't' : n += '_lse' - if self.F_squant == 't' : n += '_squant' - return n - -class FmhaFwdApiPool: - def __init__(self, mask_impl): - self.pool = dict() - self.mask_impl = mask_impl - - def register_traits(self, trait : FmhaFwdApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.pool.keys(): - self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() - - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, - F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) - -@dataclass -class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along qk seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm : int # number of warps along q seqlen (block warps) - F_rn : int # number of warps along k seqlen(not used) - F_rk : int # number of warps along gemm-k(not used) - F_wm : int # warp size along m (warp size) - F_wn : int # warp size along n - F_wk : int # warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - @property - def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ - f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") - -@dataclass -class FmhaFwdKernel: - direction : str - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline - mask_impl : str - - @property - def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0blen = self.F_tile.F_bk0blen, - F_rm = self.F_tile.F_rm, - F_rn = self.F_tile.F_rn, - F_rk = self.F_tile.F_rk, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_bias = BOOL_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) - - @property - def name(self) -> str: - # TODO: we don't encode idx here - return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\ - self.F_tile.name + '_' + self.F_pipeline.name - - @property - def filename(self) -> str: - return self.name + ".cpp" - - def api_trait(self) -> FmhaFwdApiTrait: - return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0blen=self.F_tile.F_bk0blen, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) - -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: - if direction == 'fwd': - if dtype == 'fp16' or dtype == 'bf16': - return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': - return { - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1) - } - else: - return None - else: - return None - -def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: - # this function will populate a list possible pipelines - # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' - pipelines = [] - if dtype in ['fp16', 'bf16']: - for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]): - if hdim == 256: - # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) - - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) - else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) - if receipt == 1: - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: - # no need lse kernels - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask)) - else: - assert False - return pipelines - - gen = list() - api_pool = FmhaFwdApiPool(mask_impl) - - for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()): - d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) - if d == None: - continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): - if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - k = FmhaFwdKernel(direction=direction, - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != None: - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - api_pool.register_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - -def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) - -def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: output_dir = Path(output_dir) / GEN_DIR output_dir.mkdir(parents=True, exist_ok=True) - api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - write_single_kernel(kernel, output_dir) - write_api(api_pool, output_dir) + + for api in api_list: + handler = handlers[api][HandlerId.WRITE_BLOBS] + handler(output_dir, kernel_filter, receipt, mask_impl) # list all the files that will be generated -def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) - with file_path.open('a') as f: - _, kernels = get_blobs(kernel_filter, receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + + for api in api_list: + handler = handlers[api][HandlerId.LIST_BLOBS] + handler(file_path, kernel_filter, receipt, mask_impl) if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate", - description="gen api for CK fmha kernel", + description="gen API for CK fmha kernel", + ) + parser.add_argument( + "-d", + "--direction", # we keep 'direction' option for backward compatibility + "-a", + "--api", + default='fwd', + required=False, + help="supply API(s) to generate (default: fwd). separated by comma." ) parser.add_argument( "-o", @@ -595,11 +94,13 @@ if __name__ == "__main__": default=0, required=False, help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ - " 1: generate more instance to cover all hdim" + " 1: generate more instance to cover all hdim\n" + \ + " 2: Only generate instance for Flash attention integration" ) args = parser.parse_args() + api_list = args.direction.split(',') if args.list_blobs is not None: - list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask) + list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask) + write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) \ No newline at end of file diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index 56fc8b8b1d25039d0f4ee9de46ac524468d21750..c77b700b16c2538c3490305e62bf06ad1eb659d3 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -149,11 +149,9 @@ struct mask_info return tmp; } - friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + { + mi.serialize(os); + return os; + } }; - -inline std::ostream& operator<<(std::ostream& os, const mask_info& mi) -{ - mi.serialize(os); - return os; -} diff --git a/example/ck_tile/01_fmha/script/benchmark_bwd.sh b/example/ck_tile/01_fmha/script/benchmark_bwd.sh new file mode 100755 index 0000000000000000000000000000000000000000..7591f5442a6908f9441085bf569ebacbb4271269 --- /dev/null +++ b/example/ck_tile/01_fmha/script/benchmark_bwd.sh @@ -0,0 +1,21 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_fmha_bwd +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 + +done +done +done diff --git a/example/ck_tile/01_fmha/script/benchmark.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh similarity index 100% rename from example/ck_tile/01_fmha/script/benchmark.sh rename to example/ck_tile/01_fmha/script/benchmark_fwd.sh diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test.sh deleted file mode 100755 index 4dd5c2ae12d67cb00e22dc8b85d12e8923da5c63..0000000000000000000000000000000000000000 --- a/example/ck_tile/01_fmha/script/smoke_test.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/sh -# TODO: run this script from CK root -BUILD=build -EXE=$BUILD/bin/tile_example_fmha_fwd -KNAME=1 - -export CK_WARMUP=0 -export CK_REPEAT=1 - -COMMON_ARGS='-v=1 -warmup=0 -repeat=1' -# mode=0 -# export HIP_VISIBLE_DEVICES=4 - -for prec in "fp16" "bf16" ; do -for mode in 1 0 ; do -for perm in 0 1 ; do -for vlayout in "r" "c" ; do -for hdim in 32 64 128 256 ; do -for lse in 0 1 ; do -for bias in 0 1 ; do - -# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS - -done -done -done -done -done -done -done - -for perm in 0 1 ; do -for bias in 0 1 ; do -for b in 1 2 ; do -$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS -done -done -done diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh new file mode 100755 index 0000000000000000000000000000000000000000..d6830aa2ec2ba008585411680d19d22824eb0ed4 --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -0,0 +1,34 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_fmha_bwd +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1' +set -x +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 32 64 128 ; do +for mode in 0 1 ; do +for bias in "n" "e" "a"; do +for dbias in 0 1 ; do +for p_drop in 0.0 0.2; do + +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS + +done +done +done +done +done +done +done +set +x diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh new file mode 100755 index 0000000000000000000000000000000000000000..779e8d09ee85b3186e4b6d4bfab71de253a41266 --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -0,0 +1,53 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_fmha_fwd +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +# mode=0 +# export HIP_VISIBLE_DEVICES=4 +set -x +for prec in "fp16" "bf16" ; do +for mode in 1 0 ; do +for perm in 0 1 ; do +for vlayout in "r" "c" ; do +for hdim in 32 64 128 256 ; do +for lse in 0 1 ; do +for bias in "n" "e" "a" ; do +for p_drop in 0.0 0.2; do + +# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS + +done +done +done +done +done +done +done +done + + +for perm in 0 1 ; do +for bias in "n" "e" "a" ; do +for b in 1 2 ; do +for hdim in 64 128 256 ; do +$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS +done +done +done +done +set +x diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index e10ae617dc02a21270bc160072ff4dbbe7530aac..737efd82568788e2bd5004c583a29a06f80d9978 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -4,12 +4,14 @@ #pragma once #include +#include #include #include #include #include #include #include +#include #include "ck_tile/core/container/span.hpp" @@ -37,12 +39,14 @@ std::vector to_seqstarts(ck_tile::span seqlens) std::vector generate_seqlens(mode_enum mode, unsigned count, - int32_t seqlens_sum, + int32_t seqlen_avg, + int32_t seqlen_max = -1, // if not negative, clamp max std::optional seed = std::nullopt) { assert(0 < count); - std::vector seqlens(count, seqlens_sum); + std::vector seqlens( + count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg); if(mode == mode_enum::group && 1 < count) { @@ -55,7 +59,7 @@ std::vector generate_seqlens(mode_enum mode, std::uniform_int_distribution step_dist(1, count - 1); auto next_step = std::bind(step_dist, std::ref(random_engine)); - for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat) + for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) { const size_type to_decrease = next_idx(); // make sure each elements of seqlens is always greater than 0 @@ -66,6 +70,11 @@ std::vector generate_seqlens(mode_enum mode, const size_type to_increase = (to_decrease + next_step()) % count; + if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max) + { + continue; + } + --seqlens[to_decrease]; ++seqlens[to_increase]; } @@ -76,10 +85,91 @@ std::vector generate_seqlens(mode_enum mode, std::vector generate_seqstarts(mode_enum mode, unsigned count, - int32_t seqlens_sum, + int32_t seqlen_avg, + int32_t seqlen_max = -1, std::optional seed = std::nullopt) { - return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed)); + return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed)); +} + +/* + * decode the seqlen string from cmdline + * example (assume batch=3) + * q_val=1,2,3 k_val=4,5,6 -> OK + * q_val=1,2,3 -> OK, k same as q + * q_val=1,2 -> OK, q will rand remaining 1 element, k same as q + * q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element + * q_val=1,2,3,4 -> OK, but ignore exceed one + * + * q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q + * q_val=1,2 k_val=4 -> not OK, k must have same splits with q + */ +std::tuple, + std::vector, + std::vector> +decode_seqlen(mode_enum mode, + ck_tile::index_t batch, + std::string q_val, + std::string k_val, + std::string k_pad_val, + std::optional seed = std::nullopt) +{ +#define _S2I_(str_) static_cast(std::atoi((str_).c_str())) + if(mode == mode_enum::batch) + { + ck_tile::index_t q = _S2I_(q_val); + ck_tile::index_t k = _S2I_(k_val); + auto s_q = std::vector(batch, q); + auto s_k = std::vector(batch, k < 0 ? q : k); + auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding + return std::make_tuple(s_q, s_k, s_kpad); + } + else + { + ck_tile::index_t idx = 0; + std::string::size_type pos_q = 0; + std::string::size_type pos_k = 0; + std::string::size_type pos_kp = 0; + std::vector s_q; + std::vector s_k; + std::vector s_kpad; + while(true) + { + auto found_q = q_val.find(',', pos_q); + auto found_k = k_val.find(',', pos_k); + auto found_kp = k_pad_val.find(',', pos_kp); + + ck_tile::index_t q = _S2I_( + q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q)); + ck_tile::index_t k = _S2I_( + k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k)); + ck_tile::index_t kp = _S2I_(k_pad_val.substr( + pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp)); + + s_q.push_back(q); + s_k.push_back(k < 0 ? q : k); + s_kpad.push_back(kp); + idx++; + if(found_q == std::string::npos || idx >= batch) + { + break; + } + pos_q = found_q + 1; + pos_k = found_k == std::string::npos ? pos_k : found_k + 1; + pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1; + } + if(idx < batch) + { + auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed); + auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed); + + s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); + s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); + s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back()); + } + return std::make_tuple(s_q, s_k, s_kpad); + } +#undef _S2I_ } int env_get_int(const char* var_name, int default_int) @@ -87,6 +177,6 @@ int env_get_int(const char* var_name, int default_int) char* v = getenv(var_name); int r = default_int; if(v) - r = atoi(v); + r = std::atoi(v); return r; } diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..bac5f45cd38988bdc037f676b2681d1b062b4ee7 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -0,0 +1,4 @@ +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) +target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD) \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md new file mode 100644 index 0000000000000000000000000000000000000000..433dad04e6825a7355b47a18eb52689f1d1a374f --- /dev/null +++ b/example/ck_tile/02_layernorm2d/README.md @@ -0,0 +1,22 @@ +# Layernorm2D forward + +This folder contains example for Layernorm2D forward using ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_layernorm2d_fwd -j +``` +This will result in an executable `build/bin/tile_example_layernorm2d_fwd` + +## example +``` +args: + -m m dimension (default:3328) + -n m dimension (default:4096) + -e epsilon (default:1e-5) + -v cpu validation or not (default:1) + -prec precision (default:fp16) +``` \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9cbd2861044ddcd33fe237e83d0e40241bd4ff80 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -0,0 +1,191 @@ +#include "ck_tile/host.hpp" +#include "layernorm2d_fwd.hpp" +#include + +// Host API implementation +float layernorm2d_fwd(layernorm2d_fwd_traits t, + layernorm2d_fwd_args a, + const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + using XDataType = ck_tile::half_t; + using YDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using BetaDataType = ck_tile::half_t; +#ifdef SAVE_MEAN_INV_STD + using MeanDataType = ck_tile::half_t; + using InvStdDataType = ck_tile::half_t; +#else + using MeanDataType = ck_tile::null_type; + using InvStdDataType = ck_tile::null_type; +#endif + using ComputeDataType = float; + + using thread_tile = ck_tile::sequence<4, 4>; + using warp_tile = ck_tile::sequence<8, 128>; + using block_tile = ck_tile::sequence<32, 128>; + + using Shape = ck_tile::TileLayernorm2dShape; + + using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem; + + using Kernel = ck_tile::Layernorm2dFwd; + + auto kargs = Kernel::MakeKargs( + a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N); + + const dim3 grids = Kernel::GridSize(a.M); + constexpr dim3 blocks = Kernel::BlockSize(); + + constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock; + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + } + + return 0; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "m dimension") + .insert("e", "1e-5", "epsilon") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +int main(int argc, char* argv[]) +{ + + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + float epsilon = arg_parser.get_float("e"); + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + + using XDataType = ck_tile::half_t; + using YDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using BetaDataType = ck_tile::half_t; +#ifdef SAVE_MEAN_INV_STD + using MeanDataType = ck_tile::half_t; + using InvStdDataType = ck_tile::half_t; +#else + using MeanDataType = ck_tile::null_type; + using InvStdDataType = ck_tile::null_type; +#endif + using ComputeDataType = float; + + // host verify + ck_tile::HostTensor x_host({M, N}); + ck_tile::HostTensor gamma_host({N}); + ck_tile::HostTensor beta_host({N}); + + ck_tile::HostTensor y_host_ref({M, N}); + ck_tile::HostTensor y_host_dev({M, N}); + + ck_tile::HostTensor mean_host_ref({M}); + ck_tile::HostTensor invStd_host_ref({M}); + +#ifdef SAVE_MEAN_INV_STD + ck_tile::HostTensor mean_host_dev({M}); + ck_tile::HostTensor invStd_host_dev({M}); +#endif + + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(gamma_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(beta_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + +#ifdef SAVE_MEAN_INV_STD + ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes()); +#endif + + x_buf.ToDevice(x_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + beta_buf.ToDevice(beta_host.data()); + + layernorm2d_fwd_traits traits{data_type}; + + layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + beta_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + mean_buf.GetDeviceBuffer(), + invStd_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + epsilon, + M, + N}; + + float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true}); + + std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << "[" << data_type << "]" + << " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s" + << std::flush; + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_layernorm2d_fwd( + x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + + y_buf.FromDevice(y_host_dev.data()); + + pass = ck_tile::check_err(y_host_dev, y_host_ref); + +#ifdef SAVE_MEAN_INV_STD + mean_buf.FromDevice(mean_host_dev.data()); + pass &= ck_tile::check_err(mean_host_dev, mean_host_ref); + + invStd_buf.FromDevice(invStd_host_dev.data()); + pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref); +#endif + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; + } + + std::cout << std::endl << std::flush; + + return !pass; +} diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4d1aac0994c592bbbdf675cd0fcec30dac42b695 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/layernorm2d.hpp" +#include + +struct layernorm2d_fwd_traits +{ + std::string data_type; +}; + +struct layernorm2d_fwd_args +{ + const void* p_x; + const void* p_gamma; + const void* p_beta; + void* p_y; + void* p_mean; + void* p_invStd; + float epsilon; + ck_tile::index_t M; + ck_tile::index_t N; +}; + +// host API +float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index d2b086e043211b2d600dc9bf1cd1d9d45d43888d..995d193f10f7e601af57e96d845fc8c79eccff3f 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -3,3 +3,4 @@ include_directories(AFTER ) add_subdirectory(01_fmha) +add_subdirectory(02_layernorm2d) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 0bda8b7590a436f48878d9419302377334e5e334..9528a30b4b2109385414ca18fef9d1ce3b15be68 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -4,12 +4,19 @@ #pragma once #include "ck/config.h" +#include "ck/utility/env.hpp" #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" #endif +// environment variable to enable logging: +// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED +CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + +// to do: add various levels of logging with CK_LOG_LEVEL + #define CK_TIME_KERNEL 1 // constant address space for kernel parameter @@ -62,6 +69,9 @@ #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) #define __gfx11__ #endif +#if defined(__gfx1200__) || defined(__gfx1201__) +#define __gfx12__ +#endif // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code @@ -70,7 +80,7 @@ #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx103__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#elif defined(__gfx11__) +#elif defined(__gfx11__) || defined(__gfx12__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif @@ -82,7 +92,7 @@ #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 -#elif defined(__gfx11__) +#elif defined(__gfx11__) || defined(__gfx12__) #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8_GFX11 @@ -103,13 +113,6 @@ #define CK_USE_AMD_MFMA_GFX940 #endif -// WMMA instruction -#ifndef __HIP_DEVICE_COMPILE__ // for host code -#define CK_USE_AMD_WMMA -#elif defined(__gfx11__) // for GPU code -#define CK_USE_AMD_WMMA -#endif - // buffer load #define CK_USE_AMD_BUFFER_LOAD 1 @@ -148,7 +151,7 @@ #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1 // LDS direct loads using inline assembly -#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1 +#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0 // set stochastic rounding as default for f8 conversions #define CK_USE_SR_F8_CONVERSION 1 @@ -225,17 +228,17 @@ // workaround: compiler issue on gfx908 #define CK_WORKAROUND_SWDEV_388832 1 -// flag to enable (1) or disable (0) the debugging output in some kernels -#define DEBUG_LOG 0 - // denorm test fix, required to work around dissue #ifndef CK_WORKAROUND_DENORM_FIX #define CK_WORKAROUND_DENORM_FIX 0 #else -// enable only on MI200 +// enable only for gfx90a #define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) #endif // CK_WORKAROUND_DENORM_FIX +// set flag to 1 to build deprecated instances +#define CK_BUILD_DEPRECATED 1 + namespace ck { enum struct InMemoryDataOperationEnum diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 13e5268752ea149a395fdd55e0f730413a9e2170..83af2efe88ec084da4be46756ad142b2456b8325 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -65,23 +65,28 @@ inline bool is_lds_direct_load_supported() ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; } -inline bool is_navi1_supported() +inline bool is_gfx101_supported() { return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" || ck::get_device_name() == "gfx1012"; } -inline bool is_navi2_supported() +inline bool is_gfx103_supported() { return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" || ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" || ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036"; } -inline bool is_navi3_supported() +inline bool is_gfx11_supported() { return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; } +inline bool is_gfx12_supported() +{ + return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; +} + } // namespace ck diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index 805fb571fba0f5cd8bae83697be6c65025b52428..9d9974d49db73861a281ceb6c9eb6a5fa7ca2bc7 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -5,6 +5,7 @@ #include #include +#include #include "ck/ck.hpp" #include "ck/stream_config.hpp" @@ -103,35 +104,41 @@ inline void flush_icache() hip_check_error(hipGetLastError()); } // if TimePrePress == false, return time does not include preprocess's time -template +template float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, - Args& args) + GemmArgs& gemm_args, + Args... args) { #if CK_TIME_KERNEL #define MEDIAN 1 if(stream_config.time_kernel_) { -#if DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up %d times\n", stream_config.cold_niters_); -#endif + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } // warm up for(int i = 0; i < stream_config.cold_niters_; ++i) { - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); } @@ -140,9 +147,10 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { return 0.0; } -#if DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } #if MEDIAN std::set times; @@ -169,7 +177,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, preprocess(); } // run real kernel - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); // end real kernel @@ -183,13 +191,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, total_time += cur_time; #endif -#if DEBUG_LOG - std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; - printf("args.p_a_grid: %p, args.p_b_grid:%p\n", - static_cast(args.p_a_grid), - static_cast(args.p_b_grid)); -#endif + printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n", + static_cast(gemm_args.p_a_grid), + static_cast(gemm_args.p_b_grid)); + } } #if MEDIAN @@ -212,13 +221,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, else { preprocess(); - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); return 0; } #else - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); return 0; diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 1ed7686e7fc0b10daf60f8257e32115d6c2dc9ff..a616433ac9ce097ca7ba5245b108b01a112d477e 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -20,18 +20,19 @@ float launch_and_time_kernel(const StreamConfig& stream_config, #if CK_TIME_KERNEL if(stream_config.time_kernel_) { -#if DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up %d times\n", stream_config.cold_niters_); -#endif + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } // warm up for(int i = 0; i < stream_config.cold_niters_; ++i) { @@ -40,9 +41,10 @@ float launch_and_time_kernel(const StreamConfig& stream_config, } const int nrepeat = stream_config.nrepeat_; -#if DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } hipEvent_t start, stop; hip_check_error(hipEventCreate(&start)); @@ -93,18 +95,19 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, #if CK_TIME_KERNEL if(stream_config.time_kernel_) { -#if DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up %d times\n", stream_config.cold_niters_); -#endif + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } // warm up preprocess(); for(int i = 0; i < stream_config.cold_niters_; ++i) @@ -114,9 +117,10 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, } const int nrepeat = stream_config.nrepeat_; -#if DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } hipEvent_t start, stop; hip_check_error(hipEventCreate(&start)); diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index f68473c292a9c972e32c670f71846cd88bbc9882..c152cbfb1ea1eed1a77e1172b4dd3c398ebca98b 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -1952,7 +1952,7 @@ struct Modulo } }; -template +template struct Xor { using LowerIndex = MultiIndex<2>; @@ -1981,8 +1981,15 @@ struct Xor idx_low(Number<0>{}) = idx_up[Number<0>{}]; - idx_low(Number<1>{}) = - idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]); + if constexpr(ApplyModulo) + { + idx_low(Number<1>{}) = + idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]); + } + else + { + idx_low(Number<1>{}) = idx_up[Number<1>{}] ^ idx_up[Number<0>{}]; + } } template {modulus, up_length}; } +template +__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths& low_lengths) +{ + return Xor{low_lengths}; +} + template __host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths) { - return Xor{low_lengths}; + return Xor{low_lengths}; } } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp index d62ed4b15dd3a6ea078ef7ed0b70e818a30af611..f03427a7ead1cfef5450f04dfa0dda7b38e2763f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -300,9 +300,9 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - dpp_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + dpp_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp index 5d137e67e60866fb5463579669cc112a45a95b96..1121cc45509f39752ebff4c7973642483606a8b2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -613,7 +613,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( + xdlops_gemm.Run( a_thread_vec.template AsType(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -681,7 +681,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( + xdlops_gemm.Run( a_thread_vec.template AsType(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -749,10 +749,9 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -808,10 +807,9 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -840,10 +838,9 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -901,10 +898,9 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -939,10 +935,9 @@ struct BlockwiseGemmXdlops_pipeline_v4 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index 0a7ad545b36ec79603cfe9c4370c2a74b889c66b..f597573dc2a46660ba6f10f8a61050a502f0e2ad 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -144,12 +144,12 @@ struct BlockwiseGemmXdlops_pipeline_v1 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { ignore = num_loop; return TailNumber::Full; @@ -259,7 +259,7 @@ struct BlockwiseGemmXdlops_pipeline_v1(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -319,10 +319,9 @@ struct BlockwiseGemmXdlops_pipeline_v1(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -446,12 +445,12 @@ struct BlockwiseGemmXdlops_pipeline_v1 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { ignore = num_loop; return TailNumber::Full; @@ -584,7 +583,7 @@ struct BlockwiseGemmXdlops_pipeline_v1(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -668,7 +667,7 @@ struct BlockwiseGemmXdlops_pipeline_v1(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 9acfd00858c66232a109a111dcdcc18becf95135..711c47854adad7b2880718f69ec3febe05984bb4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -153,12 +153,12 @@ struct BlockwiseGemmXdlops_pipeline_v2 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { if(num_loop % PrefetchStages == 1) { @@ -303,7 +303,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -374,7 +374,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -428,10 +428,9 @@ struct BlockwiseGemmXdlops_pipeline_v2(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -480,10 +479,9 @@ struct BlockwiseGemmXdlops_pipeline_v2(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -646,12 +644,12 @@ struct BlockwiseGemmXdlops_pipeline_v2 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { if(num_loop % PrefetchStages == 1) { @@ -821,7 +819,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -914,7 +912,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -990,7 +988,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -1066,7 +1064,7 @@ struct BlockwiseGemmXdlops_pipeline_v2(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index 3acfe0daad97cfb4051cc6e33c5bb55fb1332bcc..d47318dd0136e3679a69abc26431a8e4a8bd40e1 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -146,12 +146,12 @@ struct BlockwiseGemmXdlops_pipeline_v3 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { ignore = num_loop; return TailNumber::Full; @@ -381,7 +381,7 @@ struct BlockwiseGemmXdlops_pipeline_v3(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -440,10 +440,9 @@ struct BlockwiseGemmXdlops_pipeline_v3(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp index 75569150bdfd7f92c560b23731faf7cc9ff42160..bd5a1bedf537c3a4a31d53cae5f2d5ca1beeabb9 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -147,12 +147,12 @@ struct BlockwiseGemmXdlops_pipeline_v4 PrefetchStages; } - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) { if(num_loop % HotloopUnroll == 1) { @@ -403,7 +403,7 @@ struct BlockwiseGemmXdlops_pipeline_v4(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -472,10 +472,9 @@ struct BlockwiseGemmXdlops_pipeline_v4(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -529,10 +528,9 @@ struct BlockwiseGemmXdlops_pipeline_v4(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -562,10 +560,9 @@ struct BlockwiseGemmXdlops_pipeline_v4(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp index 8569b680e5ef69b1e237ed42ccdef541f91e048a..b6a4f05502ad03d28a98ca6d351ee613e0c4f56c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -444,7 +444,7 @@ struct BlockwiseGemmXdlops_pipeline_v5(), b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); @@ -513,10 +513,9 @@ struct BlockwiseGemmXdlops_pipeline_v5(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); a_thread_copy_.Run( a_block_desc_m0_m1_m2_k, @@ -564,10 +563,9 @@ struct BlockwiseGemmXdlops_pipeline_v5(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); a_thread_copy_.Run( @@ -607,10 +605,9 @@ struct BlockwiseGemmXdlops_pipeline_v5(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index f8ee283c67cd6dc0b3d8b484797bd2d24a2f7996..3ea19da741a0c74ce7ed2a05041cdc4bf362b4bb 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,504 @@ namespace ck { +#ifdef __gfx12__ +template +/* Option: Read from LDS, big buffer hold all threads required data + * Source + * A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * Destination + * C, non-transpose + * thread level: MRepeat x NRepeat x MAccVgprs + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * KPACK == WMMA_K = 16 + * + * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) + * Source: + * A(if skip LDS): MRepeat x KPack + * B(if skip LDS): NRepeat x KPack + * Destination + * C, non-transpose + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + */ +struct BlockwiseGemmWMMA +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto WmmaK = Number<16>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. + static constexpr index_t WaveSize = 32; + + // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer + // When not use LDS, each Row read half of whole data from source buffer, exchange the data via + // permutation + static constexpr index_t A_KRow = 2; + static constexpr index_t B_KRow = 2; + + static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + // Default, Block buffer in LDS, thread level offset enabled + __device__ static auto CalculateAThreadOriginDataIndex() + { + if constexpr(AEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + if constexpr(BEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); + + return make_tuple( + Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && + NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + + // Thread level, register decriptor. Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Provide dimension size + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Describe how data allocated in thread copy src buffer + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_assert(KPack % (A_K1 * A_KRow) == 0, ""); + static_assert(KPack % (B_K1 * B_KRow) == 0, ""); + + // basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / KPack, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + template + struct AThreadCopySelector; + + template <> + struct AThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + }; + + template <> + struct AThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatA, + FloatA, + decltype(a_block_desc_k0_m0_m1_m2_k1), + decltype(a_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + false>; + }; + + template + struct BThreadCopySelector; + + template <> + struct BThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct BThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatB, + FloatB, + decltype(b_block_desc_k0_n0_n1_n2_k1), + decltype(b_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + false>; + }; + + typename AThreadCopySelector::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; +}; +#else template (), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -411,10 +908,9 @@ struct BlockwiseGemmWMMA constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - wmma_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -529,5 +1025,6 @@ struct BlockwiseGemmWMMA typename AThreadCopySelector::type a_thread_copy_; typename BThreadCopySelector::type b_thread_copy_; }; +#endif } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 701dd04f6ca19b2de8ed8a630149cd725ac95c3f..d3f6344c27eb56ef265aca802687ae3757881e96 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -340,10 +340,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -488,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // sync point. if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) { +#ifdef __gfx12__ + asm volatile("\ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("s_barrier" ::); +#endif __builtin_amdgcn_sched_barrier(0); } static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { @@ -530,10 +536,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // TODO: insert setprio in more precise manner since we // could have more than >1 MFMA instructions in single call - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) { __builtin_amdgcn_sched_barrier(0); @@ -795,11 +800,6 @@ struct BlockwiseGemmXdlops_v2 "wrong!"); } - __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other) - : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin) - { - } - // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() { @@ -968,10 +968,9 @@ struct BlockwiseGemmXdlops_v2 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index 8ae1ba3f34c10d8359b470ed1172f1b70a7fa8b5..287c6701c3391de0d5b34d7a4da55bba76ae7d6a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -281,10 +281,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..46d0c6ac2eb4eeb0dfad627834b6d7fbaa991ee7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags, + index_t NumThreadScratch = 1> +struct ThreadGroupTensorSliceTransfer_v7r3 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + + template + using is_tuple = decltype(std::declval().IsTuple()); + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + if constexpr(is_detected::value) + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + else + threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + static_for<0, SrcDescs::Size(), 1>{}( + [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); }); + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + static_for<0, DstDescs::Size(), 1>{}( + [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); }); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7r3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index adfa1689c66509c8c194985365356740d4c90473..0eef827a5b5018efeee94d11f1b768c739d85648 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -15,6 +15,7 @@ enum struct ConvolutionForwardSpecialization Filter1x1Pad0, Filter1x1Stride1Pad0, OddC, + Filter3x3, }; inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s) @@ -25,6 +26,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; case ConvolutionForwardSpecialization::OddC: return "OddC"; + case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3"; default: return "Unrecognized specialization!"; } } diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1a4d684f1457fb89a81e4b15e723b7969be1cdde --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Streamk_V2 : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t Streamk_sel, + ck::index_t Grid_size, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/helper.hpp b/include/ck/tensor_operation/gpu/device/helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c52566509f41b3bbe5020dba215198a0f04f46ac --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/helper.hpp @@ -0,0 +1,359 @@ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include +#include + +// functions to return the corresponding structs based on generated template parameters + +using layouts = std::variant; +// return the layout type: currently this is the only type supported in MIOpen +auto layout_type(std::string type) +{ + if(type == "ck::tensor_layout::convolution::NHWGK") + { + return ck::tensor_layout::convolution::NHWGK{}; + } + throw std::runtime_error("Incorrect layout"); +} +// return the right gemm spec based on the generated template parameters +ck::tensor_operation::device::GemmSpecialization gemm_type(std::string type) +{ + if(type == "ck::tensor_operation::device::GemmSpecialization::Default") + { + return ck::tensor_operation::device::GemmSpecialization::Default; + } + if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding") + { + return ck::tensor_operation::device::GemmSpecialization::MNKPadding; + } + throw std::runtime_error("Incorrect gemm spec: " + type); +} + +// return the type of convolution +ck::tensor_operation::device::ConvolutionForwardSpecialization conv_type(std::string type) +{ + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + } + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + } + if(type == + "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + } + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + } + throw std::runtime_error("Incorrect conv spec: " + type); +} + +// Function to call on MatrixPadder via a wrapper struct +// NOTE: CK only uses MNKPadding for forward convolution +template +auto pad(ck::index_t mpb, + ck::index_t npb, + ck::index_t kpb, + ck::tensor_operation::device::GemmSpecialization gemm, + CDesc_MRaw_NRaw conv) +{ + if(gemm == ck::tensor_operation::device::GemmSpecialization::MNKPadding) + { + ck::tensor_operation::device::MatrixPadder< + ck::tensor_operation::device::GemmSpecialization::MNKPadding, + ck::index_t, + ck::index_t, + ck::index_t> + a; + a.MPerTile_ = mpb; + a.NPerTile_ = npb; + a.KPerTile_ = kpb; + auto tmp = grid_desc(a, conv); + return tmp; + } + throw std::runtime_error("Incorrect template parameters, check gemm spec"); +} + +// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num +// dims +// FIXME: add a way to properly pass in the layout +auto transform_conv(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + throw std::runtime_error("Incorrect conv spec"); +} + +auto transform_conv_3d(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + throw std::runtime_error("Incorrect conv spec"); +} + +auto transform_conv_1d(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(out_lengths, out_strides, conv_fwd); + } + throw std::runtime_error("Incorrect dims or conv spec"); +} + +template +auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder) +{ + if(m_per_block == 32 && n_per_block == 64) + { + auto b2e = ck::BlockToCTileMap_M00_N0_M01Adapt<32, 64, CGridDesc_M_N>(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 32 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<32, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 32) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 32, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 64) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 64, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 32) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 32, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 64) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 64, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 256) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 256, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 256 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + throw std::runtime_error("Incorrect template parameters"); +} + +// wrapper functions by dims to get grid size - uses above 3 functions +// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions +auto get_launch_params_1d(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} + +auto get_launch_params(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} + +auto get_launch_params_3d(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7ef4e7f1848388ceaec7915317c53002ea37b960 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,781 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + if constexpr(isMultiA || isMultiB) + { + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; + + const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + + const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + + static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } + else + { + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + + GridwiseGemm::template Run( + p_as_grid + a_batch_offset, + p_bs_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ + + device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + AsPointer, // tuples if multi AB, pointers if no + BsPointer, + DsPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfBatch, + HasMainKBlockLoop, + isMultiA, + isMultiB>(p_as_grid, + p_bs_grid, + p_ds_grid, + *p_e_grid, + a_element_op, + b_element_op, + cde_element_op, + batch_count, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map, + compute_ptr_offset_of_batch); +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + LoopScheduler LoopSched = make_default_loop_scheduler()> +struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + ds_g_n_k_wos_strides[i]); + }, + Number{}); + } + + // desc for problem definition + using AGridDesc_M_K = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using BGridDesc_N_K = remove_cvref_t({}, {}))>; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t({}, {}))>; + + // If we are using multiAB and one of the template datatype parameters is not a tuple, convert + // it to it + using GemmADataType = std::conditional_t, ADataType>; + using GemmBDataType = std::conditional_t, BDataType>; + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + // Use appropriate gridwise gemm + using GridwiseGemm = + std::conditional_t, + GridwiseGemmMultipleD_xdl_cshuffle>; + + // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. + using APointers = + std::conditional_t&, const void*>; + using BPointers = + std::conditional_t&, const void*>; + // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not + // in initializer list what is required for single const pointer). + using AGridPointer = remove_cvref_t< + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + using BGridPointer = remove_cvref_t< + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument + { + __device__ __host__ Argument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + if constexpr(isMultiA || isMultiB) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + + // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data + // type is not tuple) + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiA) + { + // p_as is tuple + p_as_grid_(i) = static_cast(p_as[i.value]); + } + else + { + // if MultiB and not MultiA then p_as is single pointer + p_as_grid_(i) = static_cast(p_as); + } + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiB) + { + // p_bs is tuple + p_bs_grid_(i) = static_cast(p_bs[i.value]); + } + else + { + // if MultiA and not MultiB then p_bs is single pointer + p_bs_grid_(i) = static_cast(p_bs); + } + }); + } + else + { + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + + // p_as and p_bs are pointers + p_as_grid_(I0) = static_cast(p_as); + p_bs_grid_(I0) = static_cast(p_bs); + } + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); + }); + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate desc for Ds/E + if constexpr(isMultiA || isMultiB) + { + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); + + if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + else + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + } + + // private: + // pointers (tuple if multi AB, pointer if no) + AGridPointer p_as_grid_; + BGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch + compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + ck::Array a_g_n_c_wis_lengths_; + ck::Array a_g_n_c_wis_strides_; + ck::Array b_g_k_c_xs_lengths_; + ck::Array b_g_k_c_xs_strides_; + ck::Array, NumDTensor> ds_g_n_k_wos_lengths_; + ck::Array, NumDTensor> ds_g_n_k_wos_strides_; + ck::Array e_g_n_k_wos_lengths_; + ck::Array e_g_n_k_wos_strides_; + ck::Array conv_filter_strides_; + ck::Array conv_filter_dilations_; + ck::Array input_left_pads_; + ck::Array input_right_pads_; + }; + + static __device__ __host__ auto MakeArgument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index d35645c06893fdcc5ed353e6a63f5b5f7d6c5803..ab3f3856aaab2eb54e7759a05be0ba82713822dc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto WmmaK = K1 == 16 ? 32 : 16; - static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; - static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true; // If true, LDS is used unconditionally static constexpr auto AEnableLds_manu = false; @@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { @@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle } else { - if(!(arg.a_kz_stride_ == 1 && - arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + if(!(arg.a_kz_stride_ == 1)) { - printf("DeviceOp: Vector Access A-k check failure\n"); - return false; + index_t LastK = + AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6); + if(LastK % ABlockTransferSrcScalarPerVector == 0) + { + printf("DeviceOp: Vector Access A-k check failure\n"); + return false; + } } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp index b01e029c03d1eb4b863eb2e8fdb2ba684bbda902..1b487502f4f94a78f8ececd7d444320fcfeaf1b5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp @@ -70,8 +70,9 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ + defined(__gfx12__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD{}([&](auto i) { using D0Layout = remove_cvref_t>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index d491ee2ea79398e266111c64edd0950cef4b8791..34b1d503afe78e324d43f5fb7df6531809756e99 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -658,27 +658,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl; + { + std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl; - std::cout << "arg.a_grid_desc_ak0_m_ak1_{" - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; - std::cout << "arg.b_grid_desc_bk0_n_bk1_{" - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) - << "}" << std::endl; + std::cout << "arg.reduce_grid_desc_m_{ " + << arg.reduce_grid_desc_m_.GetLength(I0) << "}" << std::endl; + } } -#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp index e218ee5c15da0c916a46c52d3375933490ac2dbb..1026118381bc8b855bf1d3fa5856478a044bce44 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp @@ -56,7 +56,7 @@ __global__ void bool input_permute, bool output_permute) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** @@ -159,6 +159,7 @@ __global__ void ignore = O; ignore = G0; ignore = G1; + ignore = alpha; ignore = input_permute; ignore = output_permute; #endif // end of if (defined(__gfx11__)) @@ -187,7 +188,7 @@ __global__ void index_t head_size, float alpha) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** @@ -321,7 +322,7 @@ __global__ void index_t head_size, float alpha) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** @@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle static bool IsSupportedArgument(const RawArg& arg) { - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { @@ -1435,7 +1436,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle #if 0 static bool IsSupportedArgument(const Argument& arg) { - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 00a89c47b895d13c03eca83d54b9d6971dea4d8c..e178b8f5252781ead149f5d2b78f8fc53125a3af 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -719,9 +719,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { -#if DEBUG_LOG - arg.Print(); -#endif + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + arg.Print(); + } if(!ck::is_xdl_supported()) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp index 4c6546239beb642f680b5dc639185ebe2b6e0ccb..a7a366ffbc4229df1d0b7a32de99c2bac7a8b452 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl independent_filter_strides, conv_filter_dilations, input_left_pads_with_offset, - input_right_pads); + input_right_pads, + N); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index 33e03a85e2e5e959daebfd7923312dc27f5870d6..dae16612ccaeb92273895d2de599261f76c43e9c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle // for sanity check of vector memory access for(index_t i = 0; i < NumATensor; ++i) { - as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1; - as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1; - as_max_read_elems_[i] = + tie(as_continous_dim_[i], as_max_read_elems_[i]) = CalculateMaxRead(a_ms_ks_lengths[i], a_ms_ks_strides[i]); } for(index_t i = 0; i < NumBTensor; ++i) { - bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1; - bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1; - bs_max_read_elems_[i] = + tie(bs_continous_dim_[i], bs_max_read_elems_[i]) = CalculateMaxRead(b_ns_ks_lengths[i], b_ns_ks_strides[i]); } for(index_t i = 0; i < NumDTensor; ++i) { - ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; - ds_max_read_elems_[i] = + tie(ds_continous_dim_[i], ds_max_read_elems_[i]) = CalculateMaxRead(d_ms_ns_lengths[i], d_ms_ns_strides[i]); } - e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1; - e_max_write_elems_ = CalculateMaxRead(e_ms_ns_length, e_ms_ns_stride); + tie(e_continous_dim_, e_max_write_elems_) = + CalculateMaxRead(e_ms_ns_length, e_ms_ns_stride); } // pointers @@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle BElementwiseOperation b_element_op_; CDEElementwiseOperation cde_element_op_; - // Describe whether the last part of a given dimension of A/B/D/E is consecutive - // in the memory or not. - std::array as_mz_consecutive_; - std::array as_kz_consecutive_; - std::array bs_nz_consecutive_; - std::array bs_kz_consecutive_; - std::array ds_nz_consecutive_; - bool e_nz_consecutive_; + // Describe whether the last part of a given dimension of A/B/D/E is continues dim. + std::array as_continous_dim_; + std::array bs_continous_dim_; + std::array ds_continous_dim_; + index_t e_continous_dim_; std::array as_max_read_elems_; std::array bs_max_read_elems_; @@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_a_vector_size = arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0; const bool valid_a_access_dim_m = - ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i]; + ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0; const bool valid_a_access_dim_k = - ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i]; + ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1; const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; if(!((valid_a_vector_size && valid_a_access_dim) || ABlockTransferSrcScalarPerVector == 1)) @@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_b_vector_size = arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0; const bool valid_b_access_dim_n = - BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i]; + BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0; const bool valid_b_access_dim_k = - BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i]; + BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1; const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; if(!((valid_b_vector_size && valid_b_access_dim) || BBlockTransferSrcScalarPerVector == 1)) @@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_d_vector_size = arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector read of Ds is always on N dimension. - const bool valid_d_access_dim = arg.ds_nz_consecutive_[i]; + const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1; if(!((valid_d_vector_size && valid_d_access_dim) || CDEBlockTransferScalarPerVector_NPerBlock == 1)) { @@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_e_vector_size = arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector write of E is always on N dimension. - const bool valid_e_access_dim = arg.e_nz_consecutive_; + const bool valid_e_access_dim = arg.e_continous_dim_ == 1; if(!((valid_e_vector_size && valid_e_access_dim) || CDEBlockTransferScalarPerVector_NPerBlock == 1)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 4cc60f2836bf4087ce376e30a43f45bfca35164d..f0f89f1d1b45b4128647f2ccee5e3a9580655ac4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -53,8 +53,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -443,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle } // for sanity check of vector memory access - a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1; - a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1; - a_max_read_elems_ = + tie(a_continous_dim_, a_max_read_elems_) = CalculateMaxRead(a_ms_ks_lengths, a_ms_ks_strides); - b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1; - b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1; - b_max_read_elems_ = + tie(b_continous_dim_, b_max_read_elems_) = CalculateMaxRead(b_ns_ks_lengths, b_ns_ks_strides); for(index_t i = 0; i < NumDTensor; ++i) { - ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; - ds_max_read_elems_[i] = + tie(ds_continous_dim_[i], ds_max_read_elems_[i]) = CalculateMaxRead(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); } - e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1; - e_max_write_elems_ = + tie(e_continous_dim_, e_max_write_elems_) = CalculateMaxRead(e_ms_ns_lengths, e_ms_ns_strides); } @@ -502,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle BElementwiseOperation b_element_op_; CDEElementwiseOperation cde_element_op_; - // Describe whether the last part of a given dimension of A/B/D/E is consecutive - // in the memory or not. - bool a_mz_consecutive_; - bool a_kz_consecutive_; - bool b_nz_consecutive_; - bool b_kz_consecutive_; - std::array ds_nz_consecutive_; - bool e_nz_consecutive_; + // Describe whether the last part of a given dimension of A/B/D/E is continues dim. + index_t a_continous_dim_; + index_t b_continous_dim_; + std::array ds_continous_dim_; + index_t e_continous_dim_; index_t a_max_read_elems_; index_t b_max_read_elems_; @@ -602,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle return false; } - if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" && - ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" && - std::is_same::value) + if(!ck::is_lds_direct_load_supported() && std::is_same::value) { return false; } @@ -625,8 +613,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const bool valid_a_vector_size = arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0; - const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_; - const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_; + const bool valid_a_access_dim_m = + ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0; + const bool valid_a_access_dim_k = + ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1; const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1; if(!(valid_a_vector_size && valid_a_access_dim)) @@ -636,8 +626,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const bool valid_b_vector_size = arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0; - const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_; - const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_; + const bool valid_b_access_dim_n = + BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0; + const bool valid_b_access_dim_k = + BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1; const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1; if(!(valid_b_vector_size && valid_b_access_dim)) @@ -651,7 +643,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector read of Ds is always on N dimension. const bool valid_d_access_dim = - arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1; + arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1; if(!(valid_d_vector_size && valid_d_access_dim)) { valid_ds_access = false; @@ -666,7 +658,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector write of E is always on N dimension. const bool valid_e_access_dim = - arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1; + arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1; if(!(valid_e_vector_size && valid_e_access_dim)) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp index 838305f187adcad799974cad5085ba0596e8dfbc..1b0db73fdd097210c4a31a9c456dd675b300ceb8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector& lengths, const std::vector= begin_idx; --dim_idx) { if(strides[dim_idx] == consecutive_stride) @@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector& lengths, const std::vector || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index bac124a2f1f985d16f2f2a1fed28f8bdadf7b3ce..eb0fb55f5dbc8e547a7b3e0d591c423fb4f85227 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c2b5317dd9630925c9409e843d778d1d68708292 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -0,0 +1,730 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { +#if 0 + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else +#endif + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { +#if 0 + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else +#endif + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<{}; - static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); - static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); - static constexpr auto WmmaK = K1 == 16 ? 32 : 16; - - static constexpr auto AEnableLds_auto = - (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; + + static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) && + is_same::value) + ? false + : true; static constexpr auto BEnableLds_auto = - (MWaves == 1 && is_same::value) ? false : true; + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) && + is_same::value) + ? false + : true; // If true, LDS is used unconditionally static constexpr auto AEnableLds_manu = false; @@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..452063156e2ecfa983ddc99d6896adf18505d1c2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_streamk_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + float ave_time = 0; + + index_t k_grain = KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + hipGetErrorString(hipMemsetAsync( + arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_)); + const auto Run = [&](const auto& kernel) { + dim3 grid_dim; + if(arg.Grid_size < 0) + { + int occupancy, num_cu; + hipError_t rtn; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, kernel, BlockSize, 0); + hip_check_error(rtn); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + rtn = hipGetDevice(&dev); + hip_check_error(rtn); + rtn = hipGetDeviceProperties(&dev_prop, dev); + hip_check_error(rtn); + num_cu = dev_prop.multiProcessorCount; + + arg.Grid_size = num_cu * occupancy; + grid_dim = arg.Grid_size; + } + else + grid_dim = arg.Grid_size; + + if(stream_config.flush_cache) + { + Argument arg_ = arg; + ck::utility::RotatingMemWrapper rotating_mem( + arg_, + stream_config.rotating_count, + arg_.M * arg_.K * sizeof(ADataType), + arg_.K * arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_); + } + else + { + + ave_time = launch_and_time_kernel( + stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + + return Argument{ + p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel, + Grid_size); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index c0fa9ad882a1263e44d76103b479d7b4af3f1336..5e9da459c0bd478d397ca1f9dea0a74795056c21 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -93,12 +93,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index bd264a3c8158c102959277a4f04977501a040edc..cc26936fef8489caf61d8ee430cd8f581f8fc746 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -48,18 +48,19 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index a5ae0565f3cb2a2e0b56936207169298c0881570..7f88ea692a28daf3c587826209172d6431380576 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -66,12 +66,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; @@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle K0PerBlock, ConvBackwardWeightSpecialization>{}; + static constexpr index_t MaxScalarPerVectorFP32 = 4; + static constexpr index_t WorkspaceInOutScalarPerVector = + is_same_v + ? math::min(CBlockTransferScalarPerVector_NWaveNPerXdl, MaxScalarPerVectorFP32) + : CBlockTransferScalarPerVector_NWaveNPerXdl; + // Bytes per 32 lds bank: 32 * 4 bytes static constexpr auto BankLength = 128; static constexpr auto ElePerBank = BankLength / sizeof(ADataType); @@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ADataType, BDataType, AccDataType, - EDataType, + AccDataType, InMemoryDataOperationEnum::AtomicAdd, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, @@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, - CBlockTransferScalarPerVector_NWaveNPerXdl, + WorkspaceInOutScalarPerVector, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true, @@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static constexpr auto MakeElementwiseInputSequence() { return generate_sequence_v2( - [&](auto) constexpr { return Number{}; }, + [&](auto) constexpr { return Number{}; }, Number{}); } @@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {})); using CDGridDesc_M_N = decltype(concat_tuple(Tuple{}, DsGridDesc_M_N{})); using DsGridPointerTuple = decltype(GetDsGridPointerTuple()); - using CDDataTypes = decltype(concat_tuple(Tuple{}, DsGridPointerTuple{})); + using CDDataTypes = decltype(concat_tuple(Tuple{}, DsGridPointerTuple{})); using EGridDesc_M_N = CGridDesc_M_N; static constexpr index_t ClusterLengthMPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); @@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle std::size_t GetWorkspaceSizeBytes() const { - return sizeof(EDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; + return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; } const ADataType* p_a_grid_; @@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); auto launch_gemm_kernel = [&](auto has_main_k_block_loop) { - EDataType* p_c_grid = type_convert(arg.p_workspace_); + AccDataType* p_c_grid = type_convert(arg.p_workspace_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; @@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle GridwiseGemm, ADataType, BDataType, - EDataType, + AccDataType, OutElementwiseOperation, InElementwiseOperation, element_wise::PassThrough, @@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle }; auto launch_elementwise_kernel = [&]() { - const EDataType* p_c_grid = type_convert(arg.p_workspace_); + const AccDataType* p_c_grid = type_convert(arg.p_workspace_); const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; @@ -907,7 +913,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle } // vector store C matrix into global memory - if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0 && + arg.Conv_C_ % WorkspaceInOutScalarPerVector == 0)) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e18b8b9e283c0c9fa05fe4ee0faec883c7d3bbca --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -0,0 +1,1604 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp" +#include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle + : public DeviceGroupedConvBwdWeight +{ + static_assert(is_same_v); + static_assert(is_same_v); + static_assert(is_same_v); + + using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using EDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CDEElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer_v2 = + TransformConvBwdWeightToGemmV2{}; + + static constexpr auto conv_to_gemm_transformer_v1 = + TransformConvBwdWeightToGemm{}; + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetElementwiseCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; + } + + template ::type = false> + static auto GetElementwiseCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using CElementwiseGridDesc_M_N = + remove_cvref_t())>; + + using GridwiseGemm = + GridwiseGemm_xdl_cshuffle_v3; + + static constexpr index_t ClusterLengthMPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwise = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + Sequence, + Sequence, + I1, + I1>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{}, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_e_grid_{p_wei_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + ce_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + cde_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + constexpr index_t spatial_offset = 3; + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + const auto descs = + conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + ce_elementwise_grid_desc_m_n_ = + conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_)[I2]; + + elementwise_block_2_ctile_map_ = Block2TileMapElementwise{ + ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; + + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ce_grid_desc_m_n_, + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); + } + + std::size_t GetWorkspaceSizeBytes() const + { + return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N ce_grid_desc_m_n_; + CElementwiseGridDesc_M_N ce_elementwise_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2TileMapElementwise elementwise_block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", " + << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + AccDataType* p_c_grid = type_convert(arg.p_workspace_); + + // nullptr for output, will be set after workspace set + typename GridwiseGemm::Argument gemm_arg{arg.p_a_grid_, + arg.p_b_grid_, + p_c_grid, + GemmM, + GemmN, + GemmK, + I0, + I0, + I0, + arg.k_batch_}; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge); + + float ave_time = 0; + + index_t k_grain = gemm_arg.KBatch * KPerBlock; + index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * (KPerBlock); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto num_k_per_block = + arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + + const auto clear_workspace = [&]() { + hip_check_error(hipMemsetAsync( + gemm_arg.p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); + }; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper rotating_mem( + gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + clear_workspace(); + }; + + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + else + { + ave_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + auto launch_elementwise_kernel = [&]() { + const AccDataType* p_c_grid = type_convert(arg.p_workspace_); + const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize( + arg.ce_elementwise_grid_desc_m_n_) * + arg.Conv_G_; + + std::array in_out_batch_strides = { + static_cast(arg.compute_ptr_offset_of_batch_.BatchStrideC_)}; + + const auto kernel = kernel_batched_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + I1, + I1>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(p_c_grid), + make_tuple(arg.p_e_grid_), + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_, + arg.Conv_G_, + in_out_batch_strides, + in_out_batch_strides); + }; + + float avg_time = RunGemmV3(arg, stream_config); + avg_time += launch_elementwise_kernel(); + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // Check this here, it allows to use other instances from factory even + // if workspace is not allocated + if(!arg.p_workspace_) + { + std::cerr << "Warning: Workspace for " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(NDimSpatial == 1) + { + if constexpr(!is_GNWK_GKXC_GNWC()) + { + return false; + } + } + else if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGK_GKYXC_NHWGC() || + is_GNHWK_GKYXC_GNHWC())) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC() || + is_GNDHWK_GKZYXC_GNDHWC())) + { + return false; + } + } + else + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + if constexpr(NumGroupsToMerge > 1) + { + // support only if whole M and N can be proccessed on one block + if(!(GemmM <= MPerBlock && GemmN <= NPerBlock)) + { + return false; + } + if(!(arg.Conv_C_ == 1 && arg.Conv_K_ == 1)) + { + return false; + } + if(arg.Conv_G_ % NumGroupsToMerge != 0) + { + return false; + } + } + + if(!(arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0)) + { + if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1)) + { + return false; + } + if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << NumGroupsToMerge + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index b9436c21a40411118bfce4a61aae2f461dc5b174..5738be0fb327f167a4f86cb66724cfbf0fa67542 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { // check device - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 96854e9a8d14c6cb191a948ca53d3c67df146a88..3babd1896f8e90323949c49c21e4c33a48e988f0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -61,12 +61,9 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index c3023301f39f9edde9ecc9c0aadde9b14609861b..c3fe54b0759e194ecdca65f57aee314bcb9383de 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -90,19 +90,20 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -266,7 +267,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + a_g_n_c_wis_lengths[I1]); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -312,8 +314,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK const std::array& e_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -666,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK // check device if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || - ck::is_navi2_supported() || ck::is_navi3_supported())) + ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index d731e5ddac321dc093501b40c935b1ebb162a856..c6b84b613c7812dc47c5924bc4cc9fac68fb5d0f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -107,7 +107,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx11__)) + defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -263,7 +263,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd& c_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(c_g_n_k_wos_lengths, - c_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -601,8 +602,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd @@ -85,7 +86,6 @@ __global__ void const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, - const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -93,18 +93,21 @@ __global__ void const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + const ComputePtrOffsetOfG compute_ptr_offset_of_groups, + const ComputePtrOffsetOfN compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) + // offset base pointer for each work-group - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -114,30 +117,45 @@ __global__ void DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); if constexpr(isMultiA || isMultiB) { AsPointer p_as_grid_grp; BsPointer p_bs_grid_grp; - const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); + + // compute_ptr_offset_of_n_ not need BatchStrideB so + // in case of MultiA is false but isMultiB is true + // BatchStrideA_ is not tuple. + if constexpr(isMultiA) + { + const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx); - static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); - static_for<0, NumATensor, 1>{}( - [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}([&](auto i) { + p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i]; + }); + } + else + { + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + static_for<0, 1, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; }); + } - const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); static_for<0, NumBTensor, 1>{}( - [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; }); GridwiseGemm::template Run( p_as_grid_grp, p_bs_grid_grp, p_ds_grid_grp, - p_e_grid + e_batch_offset, + p_e_grid + e_group_offset + e_n_offset, p_shared, a_element_op, b_element_op, @@ -150,16 +168,19 @@ __global__ void } else { - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); GridwiseGemm::template Run( - p_as_grid + a_batch_offset, - p_bs_grid + b_batch_offset, + p_as_grid + a_group_offset + a_n_offset, + p_bs_grid + b_group_offset, p_ds_grid_grp, - p_e_grid + e_batch_offset, + p_e_grid + e_group_offset + e_n_offset, p_shared, a_element_op, b_element_op, @@ -175,7 +196,6 @@ __global__ void ignore = p_bs_grid; ignore = p_ds_grid; ignore = p_e_grid; - ignore = batch_count; ignore = a_grid_desc_k0_m_k1; ignore = b_grid_desc_k0_n_k1; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; @@ -183,7 +203,8 @@ __global__ void ignore = a_element_op; ignore = b_element_op; ignore = cde_element_op; - ignore = compute_ptr_offset_of_batch; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; ignore = block_2_ctile_map; #endif } @@ -261,7 +282,8 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + index_t NumGroupsToMerge = 1> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle : public DeviceGroupedConvFwdMultipleABD= 1); + static constexpr bool isMultiA = is_detected::value; static constexpr bool isMultiB = is_detected::value; @@ -293,7 +317,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static constexpr auto I3 = Number<3>{}; static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; + TransformConvFwdToGemm{}; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; @@ -309,7 +333,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t Conv_N) { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, @@ -321,7 +346,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + Conv_N); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -347,11 +373,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template static auto MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) + const std::array& e_g_n_k_wos_strides, + const index_t Conv_N) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -363,24 +390,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Pass e_g_n_k_wos_lengths for logical broadcast. static auto MakeDsGridDescriptor_M_N( const std::array& e_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides) + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const index_t Conv_N) { return generate_tuple( [&](auto i) { using DLayout = remove_cvref_t>; - return DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - ds_g_n_k_wos_strides[i]); + return DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], Conv_N); }, Number{}); } // desc for problem definition using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>; using BGridDesc_N_K = remove_cvref_t({}, {}))>; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t({}, {}, 1))>; // If we are using multiAB and one of the template datatype parameters is not a tuple, convert // it to it @@ -468,6 +496,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, + conv_N_per_block_{ + conv_to_gemm_transformer.template GetSplitedNSize( + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, b_g_k_c_xs_lengths, @@ -477,12 +511,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads)}, + input_right_pads, + conv_N_per_block_)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)}, a_grid_desc_ak0_m_ak1_{ GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, b_grid_desc_bk0_n_bk1_{ @@ -490,7 +525,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, - compute_ptr_offset_of_batch_{}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -511,8 +547,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle if constexpr(isMultiA || isMultiB) { static_for<0, NumATensor, 1>{}([&](auto i) { - // Init compute_ptr_offset_of_batch_ for multiple AB - compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + // Init compute_ptr_offset_of_groups_ for multiple AB + compute_ptr_offset_of_groups_.BatchStrideA_(i) = + a_g_n_c_wis_strides[0] * NumGroupsToMerge; // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data // type is not tuple) @@ -524,16 +561,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { // p_as is tuple p_as_grid_(i) = static_cast(p_as[i.value]); + // compute_ptr_offset_of_n_ not need BatchStrideB so + // in case of MultiA is false but isMultiB is true + // BatchStrideA_ is not tuple. + compute_ptr_offset_of_n_.BatchStrideA_(i) = + a_g_n_c_wis_strides[1] * conv_N_per_block_; } else { // if MultiB and not MultiA then p_as is single pointer p_as_grid_(i) = static_cast(p_as); + compute_ptr_offset_of_n_.BatchStrideA_ = + a_g_n_c_wis_strides[1] * conv_N_per_block_; } }); static_for<0, NumBTensor, 1>{}([&](auto i) { - // Init compute_ptr_offset_of_batch_ for multiple AB - compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + // Init compute_ptr_offset_of_groups_ for multiple AB + compute_ptr_offset_of_groups_.BatchStrideB_(i) = + b_g_k_c_xs_strides[0] * NumGroupsToMerge; using DataType = remove_cvref_t>; // It is possible that one of the AB is a pointer and one is a tuple. @@ -553,8 +598,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } else { - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideA_ = + a_g_n_c_wis_strides[0] * NumGroupsToMerge; + compute_ptr_offset_of_groups_.BatchStrideB_ = + b_g_k_c_xs_strides[0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; // p_as and p_bs are pointers p_as_grid_(I0) = static_cast(p_as); @@ -570,13 +618,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle p_ds_grid_(i) = static_cast(p_ds[i]); // D batch stride - compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides[i][1] * conv_N_per_block_; // D desc ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_); }); - compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; // populate desc for Ds/E if constexpr(isMultiA || isMultiB) @@ -638,6 +690,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // tensor descriptors for problem definiton index_t num_group_; + index_t conv_N_per_block_; + AGridDesc_M_K a_grid_desc_m_k_; BGridDesc_N_K b_grid_desc_n_k_; DsGridDesc_M_N ds_grid_desc_m_n_; @@ -655,7 +709,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // for computing batch offset ComputePtrOffsetOfStridedBatch - compute_ptr_offset_of_batch_; + compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; // element-wise op AElementwiseOperation a_element_op_; @@ -689,8 +744,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.Print(); } - const index_t grid_size = - arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + const index_t gdy = arg.num_group_ / NumGroupsToMerge; + const index_t gdz = num_workgroups_per_Conv_N; const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); @@ -721,6 +780,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, Block2ETileMap, ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop, isMultiA, isMultiB>; @@ -728,7 +788,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return launch_and_time_kernel( stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_as_grid_, @@ -738,13 +798,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_, - arg.a_g_n_c_wis_lengths_[0], // Group count as_grid_desc_ak0_m_ak1, bs_grid_desc_bk0_n_bk1, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_etile_map_, - arg.compute_ptr_offset_of_batch_); + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); } else { @@ -763,6 +823,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, Block2ETileMap, ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop, isMultiA, isMultiB>; @@ -770,7 +831,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return launch_and_time_kernel( stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple @@ -780,13 +841,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle arg.a_element_op_, arg.b_element_op_, arg.cde_element_op_, - arg.a_g_n_c_wis_lengths_[0], // Group count arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_etile_map_, - arg.compute_ptr_offset_of_batch_); + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); } }; @@ -811,6 +872,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; + const index_t G = arg.b_g_k_c_xs_lengths_[I0]; + const index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + // check device if(get_device_name() == "gfx908") { @@ -820,15 +885,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return false; } } - else if(ck::is_lds_direct_load_supported()) - { - if constexpr(!(is_same_v || is_same_v || - is_same_v || is_same_v)) - { - return false; - } - } - else + if(!ck::is_xdl_supported()) { return false; } @@ -867,6 +924,42 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3) + { + if(C != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC()) + { + return false; + } + } + + if constexpr(NumGroupsToMerge > 1) + { + if(!(C == 1)) + { + return false; + } + if(G % NumGroupsToMerge != 0) + { + return false; + } + if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC()) + { + return false; + } + } // check vector access of A // FIXME: layout @@ -876,11 +969,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v) { - const index_t C = arg.a_g_n_c_wis_lengths_[2]; - + // Check access per C if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) { - return false; + // If not possible, check access per G + if(!(ABlockTransferSrcVectorDim == 1 && C == 1 && + is_NSpatialGK_GKSpatial_NSpatialGC() && + G % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } } } else @@ -897,8 +995,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v) { - const index_t C = arg.b_g_k_c_xs_lengths_[2]; - if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) { return false; @@ -922,8 +1018,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v || is_same_v) { - const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; - if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) { valid = false; @@ -968,8 +1062,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v) { - const index_t K = arg.e_g_n_k_wos_lengths_[2]; - if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) { return false; @@ -1120,7 +1212,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle << BBlockTransferSrcScalarPerVector << ", " << CDEBlockTransferScalarPerVector_NPerBlock << ", " << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle + << CShuffleNXdlPerWavePerShuffle << ", " + << NumGroupsToMerge << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a4d4a01a0117eb4aecde290a5ee7ecc340178e27 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -0,0 +1,1156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_fwd_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, + [[maybe_unused]] const index_t groups_count) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); + const index_t& num_blocks_per_n = groups_count; + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, + [[maybe_unused]] const index_t groups_count) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); + const index_t& num_blocks_per_n = groups_count; + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeDataType = AComputeDataType> +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + static constexpr bool isMultiD = DsDataType::Size() > 0; + static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD; + + // multi ABD not supported + static_assert(!isMultiABD, "Multi A, Mutli B and Multi D are not supported"); + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto + MakeAGridDescriptor_AK0_M_AK1(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t Conv_N) + + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + Conv_N); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeBGridDescriptor_BK0_N_BK1(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const index_t Conv_N) + + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // desc for problem definition + using EGridDesc_M_N = remove_cvref_t({}, {}, 1))>; + +#define GridwiseGemmV3TemplateParams \ + tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ + tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CShuffleDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, \ + MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ + AComputeDataType, BComputeDataType + + // Use appropriate gridwise gemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3; + + static auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const index_t M = e_grid_desc_m_n.GetLength(I0); + const index_t N = e_grid_desc_m_n.GetLength(I1); + return GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n, GridwiseGemm::CalculateMBlock(M), GridwiseGemm::CalculateNBlock(N)); + } + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t({}, {}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_as, + const void* p_bs, + const std::array&, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>&, + const std::array, NumDTensor>&, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_a_grid_{}, + p_b_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + conv_N_per_block_{ + conv_to_gemm_transformer.template GetSplitedNSize( + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, + a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + conv_N_per_block_)}, + b_grid_desc_bk0_n_bk1_{ + MakeBGridDescriptor_BK0_N_BK1(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch/N Stride + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; + + // p_as and p_bs are pointers + p_a_grid_ = static_cast(p_as); + p_b_grid_ = static_cast(p_bs); + + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + } + + void Print() const + { + std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers (tuple if multi AB, pointer if no) + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + index_t conv_N_per_block_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_M_N e_grid_desc_m_n_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float ave_time = 0; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); + + gdy *= arg.num_group_ * num_workgroups_per_Conv_N; + + index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + typename GridwiseGemm::Argument gemm_arg{ + arg.p_a_grid_, arg.p_b_grid_, arg.p_e_grid_, GemmM, GemmN, GemmK, I0, I0, I0, I1}; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper rotating_mem( + gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_, + arg.num_group_); + } + else + { + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_, + arg.num_group_); + } + }; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(get_device_name() == "gfx908") + { + // FIXME: re-enable fp64 when SWDEV-335738 is fixed + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + + if(!ck::is_xdl_supported()) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check Gridwise GEMM + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, I1 /*KBatch*/}; + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_as, + const void* p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CDEBlockTransferScalarPerVector_NPerBlock << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index ab1c4fc08f60bfae145c80104e593e2d55115971..2170a5829a59c429be52d9fe84f58e041a5ebdac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -161,11 +161,11 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t b_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + const long_index_t e_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + a_g_n_c_wis_lengths[I1]); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle const std::array& e_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index d70d462e24e4bd13b3240f918ec045baa5c3e41b..9bab947fdb55184e5058dfc9c8ec7be557db619e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + a_g_n_c_wis_lengths[I1]); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); @@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const std::array& e_g_n_k_wos_strides) { const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); + conv_to_gemm_transformer.template MakeCDescriptor_M_N( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]); const auto out_gemmm_gemmn_desc = matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); @@ -581,7 +582,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle namespace ctc = tensor_layout::convolution; // check device - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp index 9ae10441f95f80f0b95288e55a5cdd261384dfe0..3ee02558f4040a29e15799403bae90afd23b2d79 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -59,6 +59,22 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC() is_same_v; } +template +constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC() +{ + return is_NWGK_GKXC_NWGC() || + is_NHWGK_GKYXC_NHWGC() || + is_NDHWGK_GKZYXC_NDHWGC(); +} + +template +constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC() +{ + return is_GNWK_GKXC_GNWC() || + is_GNHWK_GKYXC_GNHWC() || + is_GNDHWK_GKZYXC_GNDHWC(); +} + template struct ComputePtrOffsetOfStridedBatch { @@ -68,14 +84,14 @@ template struct ComputePtrOffsetOfStridedBatch 1 || NumBTensor > 1)>> + enable_if_t<(NumATensor > 1 || NumBTensor > 1)>> { ComputePtrOffsetOfStridedBatch() = default; - ComputePtrOffsetOfStridedBatch(Array& BatchStrideAs, - Array& BatchStrideBs, - Array& BatchStrideDs, - index_t BatchStrideE) + ComputePtrOffsetOfStridedBatch(Array& BatchStrideAs, + Array& BatchStrideBs, + Array& BatchStrideDs, + long_index_t BatchStrideE) : BatchStrideA_(BatchStrideAs), BatchStrideB_(BatchStrideBs), BatchStrideDs_(BatchStrideDs), @@ -87,7 +103,7 @@ struct ComputePtrOffsetOfStridedBatch as_offset; static_for<0, NumATensor, 1>{}( - [&](auto i) { as_offset(i) = g_idx * static_cast(BatchStrideA_[i]); }); + [&](auto i) { as_offset(i) = static_cast(g_idx) * BatchStrideA_[i]; }); return as_offset; } @@ -95,7 +111,7 @@ struct ComputePtrOffsetOfStridedBatch bs_offset; static_for<0, NumBTensor, 1>{}( - [&](auto i) { bs_offset(i) = g_idx * static_cast(BatchStrideB_[i]); }); + [&](auto i) { bs_offset(i) = static_cast(g_idx) * BatchStrideB_[i]; }); return bs_offset; } @@ -103,40 +119,40 @@ struct ComputePtrOffsetOfStridedBatch ds_offset; static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + [&](auto i) { ds_offset(i) = static_cast(g_idx) * BatchStrideDs_[i]; }); return ds_offset; } [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } // alias for kernels without multiple D [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } - Array BatchStrideA_; - Array BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; - index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D + Array BatchStrideA_; + Array BatchStrideB_; + Array BatchStrideDs_; + long_index_t BatchStrideE_; + long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D }; template struct ComputePtrOffsetOfStridedBatch> + enable_if_t<(NumATensor == 1 && NumBTensor == 1)>> { ComputePtrOffsetOfStridedBatch() = default; - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) + ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA, + long_index_t BatchStrideB, + Array BatchStrideDs, + long_index_t BatchStrideE) : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideDs_(BatchStrideDs), @@ -146,38 +162,38 @@ struct ComputePtrOffsetOfStridedBatch(BatchStrideA_); + return static_cast(g_idx) * BatchStrideA_; } __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideB_); + return static_cast(g_idx) * BatchStrideB_; } __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const { Array ds_offset; static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + [&](auto i) { ds_offset(i) = static_cast(g_idx) * BatchStrideDs_[i]; }); return ds_offset; } [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } // alias for kernels without multiple D [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const { - return g_idx * static_cast(BatchStrideE_); + return static_cast(g_idx) * BatchStrideE_; } - ck::index_t BatchStrideA_; - ck::index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; - index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D + long_index_t BatchStrideA_; + long_index_t BatchStrideB_; + Array BatchStrideDs_; + long_index_t BatchStrideE_; + long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D }; template diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index bf8788a3b2c3624ba51d2e54aca6b1e7ae72df31..1f60818e39fcc7bdb9229fd38464aeae4386028e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -45,8 +45,7 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t KBatch = 1; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index 6f7d7c389404610ed7b1511ffc7656ed7f9f3d1c..060a16d1e21e9265481467bb6e03f0341c3cf886 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -39,8 +39,9 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ + defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -553,24 +554,29 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { -#if DEBUG_LOG - std::cout << "The group count is not equal to sum of skipped groups " - "and kernel args size!" - << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; + } return false; } @@ -832,11 +836,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); if(not group_arg_valid) { -#if DEBUG_LOG - std::cout << "[" << __func__ << "] group id: " << i - << " has invalid GridwiseGemm settings!" << std::endl; - gemm_arg.Print(); -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" << std::endl; + gemm_arg.Print(); + } } supported = supported && group_arg_valid; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 0a0e8072bfc8576b8bb53d5acf321e028f5d1ae3..70011124fc0275e870327a83d5c2a19cc62393ef 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -19,6 +19,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" @@ -42,16 +43,22 @@ namespace device { template + typename CDEElementwiseOperation, + BlockGemmPipelineScheduler BlkGemmPipeSched, + BlockGemmPipelineVersion BlkGemmPipelineVer> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -67,6 +74,7 @@ __global__ void constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t p_shared1[shared_size]; const auto gemm_desc_ptr = reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); @@ -81,27 +89,8 @@ __global__ void index_t gemm_tile_id_start = 0; index_t gemm_tile_id_end = 0; - using AGridDescMK = - remove_cvref_t( - 1, 1, 1))>; - using BGridDescNK = - remove_cvref_t( - 1, 1, 1))>; - using EGridDescMN = - remove_cvref_t( - 1, 1, 1))>; - using DsGridDescMN = - remove_cvref_t( - {}, {}, {}))>; - index_t M = 0, N = 0, K = 0; - index_t StrideA, StrideB, StrideE; - std::array StrideDs; - AGridDescMK a_grid_desc_mk; - BGridDescNK b_grid_desc_nk; - EGridDescMN e_grid_desc_mn; - DsGridDescMN ds_grid_desc_mn; auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); do @@ -127,31 +116,13 @@ __global__ void } b2c_tile_map = - OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset); + OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset); grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); gemm_tile_id_start = group_offset; gemm_tile_id_end = group_offset + grid_size_grp; } - StrideA = gemm_desc_ptr[group_id].StrideA; - StrideB = gemm_desc_ptr[group_id].StrideB; - StrideDs = gemm_desc_ptr[group_id].StrideDs; - StrideE = gemm_desc_ptr[group_id].StrideE; - - a_grid_desc_mk = - GridwiseGemm::template MakeAGridDescriptor_M_K(M, K, StrideA); - b_grid_desc_nk = - GridwiseGemm::template MakeBGridDescriptor_N_K(K, N, StrideB); - e_grid_desc_mn = - GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); - - static_for<0, NumDTensor, 1>{}([&](auto j) { - using DLayout = remove_cvref_t>; - ds_grid_desc_mn(j) = GridwiseGemm::template MakeEGridDescriptor_M_N( - M, N, StrideDs[j]); - }); - using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); DsGridPointer p_ds_grid; @@ -160,42 +131,268 @@ __global__ void p_ds_grid(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); }); - bool has_main_kblock_loop = - GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{})); + static constexpr index_t kbatch = 1; + static constexpr index_t k_grain = kbatch * KPerBlock; + index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + // Update tile offset if we have moved within group b2c_tile_map.UpdateTileOffset(tile_offset); - if(has_main_kblock_loop) + using Problem = typename GridwiseGemm::Problem; + auto problem = Problem(gemm_desc_ptr[group_id].M, + gemm_desc_ptr[group_id].N, + gemm_desc_ptr[group_id].K, + gemm_desc_ptr[group_id].StrideA, + gemm_desc_ptr[group_id].StrideB, + gemm_desc_ptr[group_id].StrideDs, + gemm_desc_ptr[group_id].StrideE, + kbatch); + + if(has_main_k_block_loop) { - GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid, - gemm_desc_ptr[group_id].p_e_grid, - static_cast(p_shared), - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_mk, - b_grid_desc_nk, - ds_grid_desc_mn, - e_grid_desc_mn, - b2c_tile_map); + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + GridwiseGemm::template Run_2Lds( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + static_cast(p_shared1), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else + { + GridwiseGemm::template Run_2Lds( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + static_cast(p_shared1), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } } else { - GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid, - gemm_desc_ptr[group_id].p_e_grid, - static_cast(p_shared), - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_mk, - b_grid_desc_nk, - ds_grid_desc_mn, - e_grid_desc_mn, - b2c_tile_map); + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } } tile_id += get_grid_size(); @@ -253,10 +450,12 @@ template + typename CDEShuffleBlockTransferScalarPerVectors, + BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, + typename ComputeTypeA = EDataType, + typename ComputeTypeB = ComputeTypeA> + struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop : public DeviceGroupedGemmTileLoop; + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; - template - struct OffsettedBlockToCTileMap - { - using underlying_type = UnderlyingBlockToCTileMap; - - __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, - index_t group_offset, - index_t tile_offset) - : block_to_ctile_map_{block_to_ctile_map}, - group_offset_{group_offset}, - tile_offset_{tile_offset} - { - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - return block_to_ctile_map_.CalculateBottomIndex( - make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_)); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const - { - return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); - } - - template - __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); - } - - __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const - { - return block_to_ctile_map_.CalculateGridSize(M, N); - } - - __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; } - UnderlyingBlockToCTileMap block_to_ctile_map_; - index_t group_offset_; - index_t tile_offset_; - }; - - using KernelArguments = GroupedGemmTileLoopKernelArguments; - using Block2ETileMap = BlockToCTileMap_N00_M0_N01Adapt; - using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap; + using KernelArguments = GroupedGemmTileLoopKernelArguments; + using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; // Argument struct Argument : public BaseArgument @@ -375,7 +533,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop std::vector& /* p_Bs */, std::vector>& /* p_Ds */, std::vector& /* p_Es */, - std::vector& gemm_descs, + const std::vector& gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, @@ -403,7 +561,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop const void* p_dev_gemm_args_; int occupancy_num_blocks_; int gpu_cu_count_; - const std::vector& gemm_descs_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; @@ -496,16 +653,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + CDEElementwiseOperation, + BlkGemmPipeSched, + BlkGemmPipelineVer>; return LaunchKernel(kernel, arg, dev_gemm_args, stream_config); } @@ -546,6 +709,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop << std::endl; } + // run multiple kernels + return launch_and_time_kernel(stream_config, kernel, dim3(grid_size), @@ -572,61 +737,41 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop return false; } - using DsGridDescMN = remove_cvref_t< - decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N( - {}, {}, {}))>; - bool supported = true; - for(const auto& gdesc : arg.gemm_descs_) + constexpr index_t k_batch = 1; + for(index_t i = 0; i < arg.group_count_; ++i) { - const auto M = gdesc.M_; - const auto N = gdesc.N_; - const auto K = gdesc.K_; - - const auto StrideA = gdesc.stride_A_; - const auto StrideB = gdesc.stride_B_; - const auto StrideE = gdesc.stride_C_; - const auto& StrideDs = gdesc.stride_Ds_; - - // If M dimension is unknown at launch time then validate just NK. - // If N or K dim is zero (or unknown) then the vector loads responsibility lies on - // the user. - if(N * K == 0) - continue; - - const auto a_grid_desc_mk = - GridwiseGemm::template MakeAGridDescriptor_M_K(M, K, StrideA); - const auto b_grid_desc_nk = - GridwiseGemm::template MakeBGridDescriptor_N_K(K, N, StrideB); - const auto e_grid_desc_mn = - GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); - - DsGridDescMN ds_grid_desc_mn; - static_for<0, NumDTensor, 1>{}([&](auto j) { - using DLayout = remove_cvref_t>; - ds_grid_desc_mn(j) = - GridwiseGemm::template MakeEGridDescriptor_M_N( - M, N, StrideDs[j]); - }); - - const auto b2c_tile_map = Block2ETileMap(M, N); - - if(!(GridwiseGemm::template CheckValidity(a_grid_desc_mk, - b_grid_desc_nk, - ds_grid_desc_mn, - e_grid_desc_mn, - b2c_tile_map) && - GridwiseGemm::template CheckTensorTransfersValidity( - M, N, K))) + std::array placeholder_p_ds_grid{}; + std::array stride_Ds; + std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); + using GridArg = typename GridwiseGemm::Argument; + GridArg gridwise_arg(nullptr, // p_a_grid, + nullptr, // p_b_grid, + placeholder_p_ds_grid, // p_ds_grid, + nullptr, // p_e_grid , + arg.gemm_descs_[i].M_, + arg.gemm_descs_[i].N_, + arg.gemm_descs_[i].K_, + arg.gemm_descs_[i].stride_A_, + arg.gemm_descs_[i].stride_B_, + stride_Ds, + arg.gemm_descs_[i].stride_C_, + k_batch, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + + if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) { -#if DEBUG_LOG - std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << "," << K - << "] are not supported by current template parameters!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; -#endif - supported = false; + return false; } + + supported = supported && GridwiseGemm::CheckValidity(gridwise_arg); } return supported; @@ -641,7 +786,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop std::vector& p_Bs, std::vector>& p_Ds, std::vector& p_Es, - std::vector gemm_descs, + std::vector& gemm_descs, AElementwiseOperation a_elementwise_op, BElementwiseOperation b_elementwise_op, CDEElementwiseOperation cde_elementwise_op) @@ -649,16 +794,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + CDEElementwiseOperation, + BlkGemmPipeSched, + BlkGemmPipelineVer>; int occupancy, num_cu; hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); @@ -694,16 +845,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + CDEElementwiseOperation, + BlkGemmPipeSched, + BlkGemmPipelineVer>; int occupancy, num_cu; hip_check_error( hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); @@ -737,6 +894,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop { auto str = std::ostringstream(); + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + // clang-format off str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop" << "<" @@ -758,8 +926,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop << CShuffleMXdlPerWavePerShuffle << ", " << CShuffleNXdlPerWavePerShuffle << ", " << getGemmSpecializationString(GemmSpec) << ", " - << PipelineVer << ", " - << LoopSched + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 7dfb677ecc18530a04a40502fa923d243fec4def..658f3235168f247ef888552a2a07a72f5e6fc0f5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -514,28 +514,29 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { -#if DEBUG_LOG - std::cout << "The group count is not equal to sum of skipped groups " - "and kernel args size!" - << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; + } return false; } @@ -544,11 +545,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK || is_same_v)) { @@ -958,7 +959,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma #if 0 static bool IsSupportedArgument(const Argument& arg) { - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp index 52aeefa3a41af3b4032798344790aaa9fa0150e3..9ebcb2b8c08245efb32d22404a13e2b315268d6b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads); + input_right_pads, + N); const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp index b7551e78a22a240fb5c3f345931575e360ccbbe1..cc88c1a10473a86dc7fd0bdc54c68f25c17a9682 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp @@ -60,7 +60,7 @@ __global__ void bool input_permute, bool output_permute) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** @@ -165,6 +165,7 @@ __global__ void ignore = O; ignore = G0; ignore = G1; + ignore = alpha; ignore = input_permute; ignore = output_permute; #endif // end of if (defined(__gfx11__)) @@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma static bool IsSupportedArgument(const RawArg& arg) { - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { if constexpr(!(is_same_v || is_same_v)) { @@ -950,7 +951,7 @@ struct DeviceMultiQueryAttentionForward_Wmma #if 0 static bool IsSupportedArgument(const Argument& arg) { - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index c66d2fc516babe3aa5a16de053d9ada89eabf8ef..02941531475ffe71cf76f96a7259facfaccbec00 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -180,6 +180,19 @@ struct MatrixPadder : public GemmPadder +auto grid_desc(MatrixPadder matrix_padder, + CDesc_MRaw_NRaw conv_desc) +{ + auto res = matrix_padder.PadCDescriptor_M_N(conv_desc); + return res; +} // M/N/KPerTileType could be index_t or Number<> template } }; -struct ConvInvscale -{ - /// @brief Op to multiply convolution results by inverted scale factors - /// @param e Output after scaling - /// @param c Convolution result - /// @param d0 Input scale factor - /// @param d1 Weights scale factor - /// @param d2 Output scale factor - template - __host__ __device__ void - operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; - - template <> - __host__ __device__ void operator()( - f8_t& e, const float& c, const float& d0, const float& d1, const float& d2) const - { - e = type_convert(c / d0 / d1 / d2); - }; -}; - } // namespace element_wise } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index bddf9087fe087382003f859c23f3855d8629da23..bf4a1c800fb0ce81b79118897a8f79502aa6d7ee 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -431,7 +431,7 @@ struct Relu // https://paperswithcode.com/method/gelu // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) // host code use higher accuracy "exp" and "div" -// gpu code use lower accuracy "__expf" and "rcp" function +// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function struct FastGelu { template @@ -451,7 +451,7 @@ struct FastGelu y = x / (1.f + emu); } - // device code, use lower precision "__expf" and "rcp" + // device code, use lower precision "__ocml_exp_f32" and "rcp" template <> __device__ void operator()(float& y, const float& x) const { @@ -459,7 +459,7 @@ struct FastGelu const float c1 = -2.0 * 0.035677f; const float c2 = -2.0 * 0.797885f; const float u = x * (c1 * x * x + c2); - const float emu = __expf(u); + const float emu = __ocml_exp_f32(u); y = x * ck::math::rcp(1.f + emu); } @@ -961,6 +961,95 @@ struct Elu const float alpha_; }; +struct Logistic +{ + Logistic(float alpha = 1.f) : alpha_(alpha){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + } + const float alpha_; +}; + +struct ConvInvscale +{ + __host__ __device__ ConvInvscale(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + e = type_convert(c / scale_in_ / scale_wei_ / scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + +struct ConvScale +{ + __host__ __device__ ConvScale(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + e = type_convert(c * scale_in_ * scale_wei_ * scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + +struct ConvScaleRelu +{ + __host__ __device__ ConvScaleRelu(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + float x; + Relu{}.template operator()(x, c * scale_in_ * scale_wei_); + e = type_convert(x * scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + // support fastconvert of int8 to fp16 template diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index d92f504d52ea754edadc6b258085022d638c12d1..56c37b1b7240de0e85fe2ddd6faa7264deaaa32e 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -260,7 +260,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt struct BlockToCTileMap_Grouped_M00_N0_M01Adapt @@ -908,6 +908,51 @@ struct OffsettedBlockToCTileMap UnderlyingBlockToCTileMap block_to_ctile_map_; index_t block_start_; }; +// second version with 2 offsets +template +struct OffsettedBlockToCTileMap2 +{ + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map, + index_t group_offset, + index_t tile_offset) + : block_to_ctile_map_{block_to_ctile_map}, + group_offset_{group_offset}, + tile_offset_{tile_offset} + { + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_)); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + return block_to_ctile_map_.CalculateGridSize(M, N); + } + + __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; } + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t group_offset_; + index_t tile_offset_; +}; /** * @brief Simple tile mapping which creates 3D grid of block of threads. @@ -1359,4 +1404,326 @@ struct BlockToCTileMap_GemmStreamK } }; +template +struct BlockToCTileMap_GemmStreamK_v2 +{ + static constexpr uint32_t min_k_iters_per_sk_block = 2; + static constexpr uint32_t MPerBlock = MPerBlock_; + static constexpr uint32_t NPerBlock = NPerBlock_; + static constexpr uint32_t KPerBlock = KPerBlock_; + static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_; + static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_; + + //-------------------------------------- + // pass to device + mutable uint32_t sk_num_blocks; + uint32_t sk_num_big_blocks; + uint32_t dp_start_block_idx; + uint32_t reduction_start_block_idx; + uint32_t k_iters_per_big_block; + MDiv2 n_tiles; + MDiv k_iters_per_tile; + MDiv equiv_tiles_big; // for reduction + MDiv equiv_tiles_little; // for reduction + + // prefer construct on host + __host__ __device__ BlockToCTileMap_GemmStreamK_v2( + uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1) + { + // total output tiles + uint32_t num_tiles = + math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock); + k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock)); + + uint32_t dp_tiles, dp_num_blocks, sk_total_iters; + + // default to regular DP GEMM if sk blocks == 0 + if(streamk_sel == 0) + { + sk_num_blocks = 0; + dp_tiles = num_tiles; + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + + dp_num_blocks = num_tiles; // all tile to be dp block + dp_start_block_idx = 0; + sk_total_iters = 0; // clear this tiles + } + // 2-tile sk + DP GEMM + else + { + + // check if there's enough work for DP+ stream-k + bool bigEnough = num_tiles > grid_size; + // select between stream-k strategies + uint32_t sk_tiles = 0; + if(streamk_sel == 1) // 1 tile stream-k + { + sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles; + } + else if(streamk_sel == 2) // 2-tile stream-k + { + sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles; + } + else if(streamk_sel == 3) // 3-tile stream-k + { + sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size) + : num_tiles; + } + else if(streamk_sel == 4) // 4-tile stream-k + { + sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size) + : num_tiles; + } + sk_num_blocks = sk_tiles; + // remaining tiles are DP tiles + dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0; + + sk_total_iters = k_iters_per_tile.get() * sk_tiles; + + // k_iters_per_sk_block is the floor of avg each ck block loop over tiles. + // we need to decide how many iters for each sk block + // let m = k_iters_per_sk_block + // some of the sk block (little) will cover m iters, some (big) will cover m+1 + // we have + // 1) l + b = sk_blocks + // 2) l * m + b * (m + 1) = sk_total_iters + // => (l + b) * m + b = sk_total_iters + // => sk_blocks * m + b = sk_total_iters + // => b = sk_total_iters - m * sk_blocks + // NOTE: big could be zero + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + + dp_num_blocks = dp_tiles; + dp_start_block_idx = sk_num_blocks; + } + + n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); + // using multiple blocks for parallel reduction + reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; + + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); + uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get()); + equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); + equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + } + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + __host__ __device__ uint32_t get_sk_total_iters() const + { + uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block + + (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1); + return sk_total_iters; + } + + __host__ __device__ uint32_t get_sk_tiles() const + { + // tiles for sk + uint32_t sk_total_iters = get_sk_total_iters(); + return k_iters_per_tile.div(sk_total_iters); + } + + __host__ __device__ index_t get_grid_dims() const + { + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + // return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1); + return reduction_start_block_idx + get_sk_tiles(); + } + else + return reduction_start_block_idx; + } + + __device__ uint32_t get_block_idx() const + { + // TODO: swizzle block index for better locality + return __builtin_amdgcn_readfirstlane(blockIdx.x); + } + + __device__ void + get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const + { + if(block_idx < sk_num_big_blocks) + { + iter_start = block_idx * k_iters_per_big_block; + iter_end = iter_start + k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + iter_start = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + iter_end = iter_start + (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + uint32_t sk_total_iters = get_sk_total_iters(); + uint32_t dp_iters_per_block = k_iters_per_tile.get(); + iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + iter_end = iter_start + dp_iters_per_block; + } + } + + __device__ uint32_t get_current_iter_length(uint32_t iter_start, + uint32_t iter_end, + uint32_t total_iter_length) const + { + uint32_t iter_length_mod, iter_length_quo /*unused*/; + k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); + uint32_t current_iter_length = math::min( + iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length); + return current_iter_length; + } + + __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); } + + __device__ void + get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const + { + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); + } + + __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const + { + uint32_t m_tile_idx, n_tile_idx; + uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock); + n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx); + + // // swizzle tile + uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock); + + uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m; + + const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem)) + ? tile_swizzle_sub_m + : tile_swizzle_sub_m_rem; + + uint32_t m_tile_idx_sub0, m_tile_idx_sub1; + m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m; + m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m; + + uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value; + + uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt; + + n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt; + m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt; + return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m, + n_tile_idx_with_adapt); + } + + __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const + { + static constexpr uint32_t alignment = 128; + uint32_t acc_buffer_bytes = + MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes; + return (acc_buffer_bytes + alignment - 1) / alignment * alignment; + } + + __host__ __device__ uint32_t get_workspace_size_for_semaphore() const + { + return get_sk_tiles() * sizeof(uint32_t); + } + + __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const + { + return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore(); + } + + __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, + const MDiv& equiv_tiles_) const + { + uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1); + uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1; + uint32_t quo_, rem_; + equiv_tiles_.divmod(tile_idx_, quo_, rem_); + return quo_ * max_equiv_tiles_ + rem_; + } + + __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, + uint32_t iters_per_sk_block_) const + { + return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() - + 1); + } + + __host__ __device__ uint32_t get_total_acc_buffers() const + { + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + uint32_t tiles_cover_little_blocks = + get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1); + + uint32_t total_intersec_big = + get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big); + uint32_t total_intersec_little = + get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little); + + return sk_num_blocks + total_intersec_big + total_intersec_little; + } + + __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const + { + // TODO: from big to little + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + if(tile_idx_ < tiles_cover_big_blocks) + { + uint32_t touched_sk_blocks = + (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) / + k_iters_per_big_block; + uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big); + return touched_sk_blocks + current_intersec; + } + else + { + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_; + uint32_t touched_sk_blocks = + (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) / + iters_per_little_sk_block; + uint32_t current_intersec = + get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little); + return get_total_acc_buffers() - (touched_sk_blocks + current_intersec); + } + } + + __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const + { + uint32_t iters_per_big_sk_block = k_iters_per_big_block; + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + if(block_idx_ < sk_num_big_blocks) + { + uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block + + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big); + return block_idx_ + current_intersec; + } + else + { + uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_; + uint32_t touched_tiles = k_iters_per_tile.div( + block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little); + return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec); + } + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp index 16717ff8197244024601a641d0d40c9b645f51da..1754e07e6a2b98a2de3b0056cfda509c76eea0c2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma if constexpr(B0EnableLds) { // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 - constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( B0BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma if constexpr(B1EnableLds) { // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 - constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); - constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); + constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_LRow = I2; +#else constexpr auto B_LRow = I1; +#endif return transform_tensor_descriptor( B1BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 67e211ef8d620a279f94088478cfb9c9a51ab05d..21dac6f9e9eb59ae5af6bc81aea5e496231fc3b8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -50,8 +50,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, @@ -80,7 +79,7 @@ __global__ void ignore = b_element_op; ignore = c_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx1100__)) +#endif // end of if (defined(__gfx11__)) } // Assume B is Col-Major @@ -303,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif return transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -361,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 82d010a99a08c5325b77aebfede3611643f519de..b3b057c80a49c2e9f17938d2b839dc1856516435 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -54,18 +54,18 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -147,7 +147,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) // printf("entry kernel launch"); __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; @@ -155,12 +155,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -237,7 +237,7 @@ __global__ void const CDEElementwiseOperation cde_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; GridwiseOp::template Run(p_a_grid, @@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma } else { + constexpr auto A_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma } else { + constexpr auto B_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif return transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma // *Caution Here repeat is shuffle repeat GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; } @@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma const auto M = e_grid_desc_m_n.GetLength(I0); const auto N = e_grid_desc_m_n.GetLength(I1); - const auto MBlock = M / MPerBlock; - const auto NBlock = N / NPerBlock; + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( e_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 8e4117593c43e22ff9a43f2d107f112586b8c5a6..4458b9356dd64000997a93c23fc9241d1e27d8b0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -45,7 +45,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, @@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma } else { + constexpr auto A_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma } else { + + constexpr auto B_KRow = I2; constexpr auto KWmmaPerblock = KPerBlock / WmmaK; - constexpr auto K0PerWmma = WmmaK / 2 / K1; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread return make_naive_tensor_descriptor( make_tuple(Number{}, @@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma if constexpr(AEnableLds) { // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else constexpr auto A_KRow = I1; +#endif + return transform_tensor_descriptor( ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma if constexpr(BEnableLds) { // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 - constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else constexpr auto B_KRow = I1; +#endif return transform_tensor_descriptor( BBlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), make_unmerge_transform(make_tuple( Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; - using DefaultBlock2CTileMap = - remove_cvref_t; - struct SharedMemTrait { // LDS allocation for A and B: be careful of alignment @@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma b_block_space_size_aligned * sizeof(BDataType)); }; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + template __device__ static void Run(const ADataType* __restrict__ p_a_grid, const BDataType* __restrict__ p_b_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d2a06ba9afe5b89be9b19c3bb1183d7b21c219fb --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp @@ -0,0 +1,1369 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id = 0) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( + make_multi_index(static_cast(blockIdx.x))); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (KPerBlock * problem.KBatch)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id = 0) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( + make_multi_index(static_cast(blockIdx.x))); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (KPerBlock * problem.KBatch)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff10215353cf1b47d598e7dc28bd849ee139e90b --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -0,0 +1,2010 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_streamk_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t Streamk_sel_, + index_t Grid_size_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + Streamk_sel{Streamk_sel_}, + Grid_size{Grid_size_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, 1)}, + KPadded{CalculateKPadded(K_, 1)}, + AK0{CalculateAK0Padded(K_, 1)}, + BK0{CalculateBK0Padded(K_, 1)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel + << ", Grid size:" << Grid_size << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t Streamk_sel; + mutable index_t Grid_size; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t Streamk_sel_, + index_t Grid_size_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Problem& problem, unsigned int kbatch_id, unsigned int orig_K) + { + if constexpr(is_same_v) + { + a_k_split_offset = kbatch_id * problem.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = kbatch_id * problem.KRead * problem.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = kbatch_id * problem.KRead * problem.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = kbatch_id * problem.KRead; + } + + if(kbatch_id < static_cast(problem.KBatch - 1)) + { + problem.K = problem.KRead; + } + else + { + problem.K = orig_K - problem.KRead * (problem.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + + if(karg.K <= 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(is_same, bhalf_t>::value) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + Problem& problem) + { + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M, + problem.N, + AK0Number * problem.KPadded, + problem.Grid_size, + problem.Streamk_sel); + uint32_t iter_start, iter_end; + bool is_sk_block, is_dp_block; + index_t num_k_block_main_loop; + + for(auto block_idx = get_block_1d_id(); + block_idx < block_2_ctile_map_streamk.get_grid_dims(); + block_idx += gridDim.x) + { + + is_sk_block = + static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; + is_dp_block = + static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && + static_cast(block_idx) < + block_2_ctile_map_streamk.reduction_start_block_idx; + + block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); + num_k_block_main_loop = iter_end - iter_start; + + while(true) + { + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_2_ctile_map_streamk.get_current_iter_length( + iter_start, iter_end, num_k_block_main_loop)); + uint32_t tile_idx, iter_offset; + block_2_ctile_map_streamk.get_tile_idx_with_offset( + iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K, + problem.KPadded, + problem.N, + problem.NPadded, + problem.StrideB, + problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto block_work_idx = + block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = + make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = + make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + // CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = SpaceFillingCurve< + Sequence<1, MPerBlock, 1, NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + if(is_dp_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(is_sk_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // exit condition + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + // make sure next loop LDS is ready for use + block_sync_lds(); + } + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + Problem& problem) + { + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + Block2CTileMap_streamk block_2_ctile_map_streamk( + problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size); + uint32_t iter_start, iter_end; + bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block; + index_t num_k_block_main_loop; + + for(auto block_idx = get_block_1d_id(); + block_idx < block_2_ctile_map_streamk.get_grid_dims(); + block_idx += gridDim.x) + { + is_sk_block = + static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; + is_dp_block = + static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && + static_cast(block_idx) < + block_2_ctile_map_streamk.reduction_start_block_idx; + + block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); + num_k_block_main_loop = iter_end - iter_start; + + { + + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_2_ctile_map_streamk.get_current_iter_length( + iter_start, iter_end, num_k_block_main_loop)); + uint32_t tile_idx, iter_offset; + block_2_ctile_map_streamk.get_tile_idx_with_offset( + iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(problem.K, + problem.KPadded, + problem.N, + problem.NPadded, + problem.StrideB, + problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto block_work_idx = + block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = + make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = + make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + // CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = SpaceFillingCurve< + Sequence<1, MPerBlock, 1, NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + if(is_dp_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(is_sk_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index cadfb29bbfb659d2c133f6a07a6500b2537941fd..bda2ded95715b11744a15a847fdc53ab1aeb0ad1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -34,8 +34,7 @@ __global__ void // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); @@ -48,7 +47,7 @@ __global__ void karg); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template {}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -671,7 +669,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(AK1Number)), @@ -742,8 +740,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -805,7 +803,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(BK1Number)), @@ -935,12 +933,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.M % MPerBlock == 0)) { -#if DEBUG_LOG - std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } @@ -952,12 +950,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.N % NPerBlock == 0)) { -#if DEBUG_LOG - std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } @@ -971,12 +969,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 auto K_t = karg.KBatch * KPerBlock; if(!(karg.K % K_t == 0)) { -#if DEBUG_LOG - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " - << karg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -995,13 +993,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % ABlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1009,13 +1007,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % ABlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1024,13 +1022,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % BBlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1038,13 +1036,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % BBlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1053,14 +1051,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { -#if DEBUG_LOG - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } @@ -1068,25 +1067,26 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { -#if DEBUG_LOG - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } if constexpr(is_same, bhalf_t>::value) { -#if DEBUG_LOG - std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } if(karg.KBatch > 1) { return false; @@ -1123,7 +1123,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 } template - __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) { const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( @@ -1141,26 +1141,22 @@ struct GridwiseGemm_xdl_cshuffle_v3 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - template __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, CDataType* p_c_grid, void* p_shared, - const Problem& problem) + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) { - const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( - problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); - const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( - problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -1508,12 +1504,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 template - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, - const Problem& problem) + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) { const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); @@ -1521,11 +1516,42 @@ struct GridwiseGemm_xdl_cshuffle_v3 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); + Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -1879,6 +1905,43 @@ struct GridwiseGemm_xdl_cshuffle_v3 }); } } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index ab3449b1c8596ddcb73fea9d3ae7fa22b5a154ce..f9071bd29d212d0ee8310d9efd3ca84b91a0ddf0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -33,8 +33,7 @@ __global__ void // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -49,7 +48,7 @@ __global__ void karg.c_element_op); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template {}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -849,7 +847,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(AK1Number)), @@ -920,8 +918,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -983,7 +981,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(BK1Number)), @@ -1113,12 +1111,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.M % MPerBlock == 0)) { -#if DEBUG_LOG - std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } @@ -1130,12 +1128,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.N % NPerBlock == 0)) { -#if DEBUG_LOG - std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } @@ -1149,12 +1147,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 auto K_t = karg.KBatch * KPerBlock; if(!(karg.K % K_t == 0)) { -#if DEBUG_LOG - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " - << karg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1173,13 +1171,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % ABlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1187,13 +1185,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % ABlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1202,13 +1200,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % BBlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1216,13 +1214,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % BBlockTransferSrcScalarPerVector != 0) { -#if DEBUG_LOG - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1231,14 +1229,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { -#if DEBUG_LOG - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } @@ -1246,14 +1245,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { -#if DEBUG_LOG - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__ << std::endl; - -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3a1ac6c6de485640bec20b7a9585a625b75e825b --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -0,0 +1,2136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmMultiD_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ __device__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ __device__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead; + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeA); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeB); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) + + b_block_space_size_aligned * sizeof(LDSTypeB)), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + LDSTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; + Run_2Lds( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + LDSTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp index 94306a4c958b6183c299a04169b93301000080d7..bac8c3288611001ee154e2b500ba64ff9979f909 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -38,8 +38,7 @@ __global__ void const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -52,7 +51,7 @@ __global__ void ignore = a_element_op; ignore = b_element_op; ignore = c_element_op; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + ignore = src_idx; + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v_this_row; + // int type temp value due to intrinsic requirement + int temp = 0; + + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); + + // apply intra-row permute. + if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } + + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); + }); + }); + } + ElementwiseOperation element_op_{}; +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ea074144b66589bc861f6c57587bf9b658110046 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -0,0 +1,648 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" +#include "ck/tensor/static_tensor.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags, // Sequence + index_t NumThreadScratch = 1> +struct ThreadwiseTensorSliceTransfer_v7r3 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr auto SrcScalarPerVector = SrcScalarPerVectors{}[I0]; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve, + false>; + + using DstSpaceFillingCurve = SpaceFillingCurve, + false>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r3( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); + + bool oob_val = true; + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + + if constexpr(SrcScalarPerVectors{}[i] == 1) + { + auto data_types = SrcDatas{}; + using DataType = remove_cvref_t; + const auto tmp = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + + static_for<0, SrcScalarPerVector, 1>{}( + [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); + } + else + { + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + } + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return elm_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val; + + // move coordinate + if constexpr(iAccess.value != src_num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + +#if 1 + template + __device__ void OOBCheck(Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess]; + auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess]; + + static_for<0, nDst, 1>{}([&](auto i) { + using elm_vector_t = typename remove_cvref_t::type; + elm_vectors(i).template AsType()(I0) = + oob_val ? elm_vectors(i).template AsType()[I0] : elm_vector_t{0}; + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + }); + } +#endif + + template + __device__ void + TransposeFromElmToDst(Number thread_scratch_id = Number{}) + { + using DstData = remove_cvref_t; + + using ElmThreadScratch = + StaticTensorTupleOfVectorBuffer; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer; + + ElmThreadScratch elm_thread_scratch_; + DstThreadScratch dst_thread_scratch_; + + elm_thread_scratch_.data_ = + bit_cast(elm_vectors_tuple_[thread_scratch_id]); + + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return elm_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from + // dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}( + [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); + } + + dst_vectors_tuple_(thread_scratch_id) = bit_cast(dst_thread_scratch_.data_); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + OOBCheck(thread_scratch_id); + TransposeFromElmToDst(thread_scratch_id); + + // loop over space-filling curve + static_for<0, dst_num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != dst_num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(src_num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(dst_num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + return DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + // constexpr auto src_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + // constexpr auto dst_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using ElmVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); + + using ElmVectorTuple = StaticallyIndexedArray; + using DstVectorTuple = StaticallyIndexedArray; + + StaticallyIndexedArray elm_vectors_tuple_; + StaticallyIndexedArray dst_vectors_tuple_; + + using OOBVectorTuple = StaticallyIndexedArray; + StaticallyIndexedArray oob_vectors_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..33c07f34f77bac79ed6a728ec720e840054be830 --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp @@ -0,0 +1,409 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_smfmac.hpp" + +namespace ck { + +enum struct SmfmacInstr +{ + smfmac_f32_16x16x32f16 = 0, + smfmac_f32_32x32x16f16, + smfmac_f32_16x16x32bf16, + smfmac_f32_32x32x16bf16, +}; + +template +struct smfmac_type; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_16x16x32f16::Run(a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_32x32x16f16::Run(a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_16x16x32bf16::Run(a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_32x32x16bf16::Run(a, b, idx, reg_c); + } +}; + +template +struct SmfmacSelector +{ + template + static constexpr auto GetSmfmac(); + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_16x16x32f16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_32x32x16f16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_16x16x32bf16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_32x32x16bf16; + } + + static constexpr auto selected_smfmac = + smfmac_type()>{}; + + __host__ __device__ constexpr SmfmacSelector() + { + static_assert(selected_smfmac.group_size * selected_smfmac.num_groups_per_blk == + selected_smfmac.num_regs_per_blk, + "wrong! num_regs_per_blk"); + + static_assert(selected_smfmac.num_threads_per_blk == selected_smfmac.n_per_blk, + "n_per_blk != num_threads_per_blk"); + + static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.num_input_blks == + selected_smfmac.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + + static_assert(selected_smfmac.num_output_blks == selected_smfmac.num_input_blks || + selected_smfmac.num_output_blks == 1, + "incorrect num_output_blks"); + + static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.wave_size == + selected_smfmac.m_per_blk * selected_smfmac.n_per_blk, + "num_regs_per_blk incorrect"); + + static_assert(selected_smfmac.is_k_reduction || + (selected_smfmac.num_input_blks == selected_smfmac.num_output_blks), + "is_k_reduction wrong!"); + } + + static constexpr index_t GetKPerXdlops() + { + return (selected_smfmac.is_k_reduction ? selected_smfmac.num_input_blks : 1) * + selected_smfmac.k_per_blk; + } + + static constexpr index_t GetK1PerXdlops() { return selected_smfmac.k_per_blk; } +}; + +template +struct SparseXdlopsGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __device__ static constexpr index_t GetNumBlks() { return smfmac_instr.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / + (smfmac_instr.m_per_blk * smfmac_instr.n_per_blk * smfmac_instr.num_output_blks); + } + + __host__ __device__ constexpr SparseXdlopsGemm() + { + static_assert(NPerXdlops == 16 || NPerXdlops == 32, + "Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops"); + + static_assert(MPerXdlops == 16 || MPerXdlops == 32, + "Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops"); + + static_assert(KPack % smfmac_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + } + + // XDL output supporting C = A * B + // M2_N2 -> M2_M3_M4_N2 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); + } + + template + __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) + { + const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_g_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(G), + make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(smfmac_instr.num_groups_per_blk, + smfmac_instr.num_input_blks, + smfmac_instr.group_size)), + make_pass_through_transform(smfmac_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{}, + Sequence<8>{})); + } + + __device__ static constexpr index_t GetRegSizePerXdlops() + { + return MPerXdlops * NPerXdlops / smfmac_instr.wave_size; + } + + __device__ static constexpr index_t GetWaveSize() { return smfmac_instr.wave_size; } + + template + __device__ void + Run(const FloatA& p_a_wave, const FloatB& p_b_wave, const Idx& idx, FloatC& p_c_thread) const + { + static_assert(is_same::value || is_same::value, + "base base_type must be half or bfloat16!"); + + static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) { + smfmac_instr.template run( + p_a_wave[k], p_b_wave[k], idx[k], p_c_thread); + }); + } + + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, smfmac_instr.num_input_blks, smfmac_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(smfmac_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(smfmac_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * smfmac_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * smfmac_instr.m_per_blk + blk_id * smfmac_instr.group_size; + + return CIndex{m_offset, n_offset}; + } + + __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + return CIndex4D{I0, blk_id, I0, blk_td}; + } + + static constexpr auto smfmac = + SmfmacSelector{}; + + static constexpr auto smfmac_instr = smfmac.selected_smfmac; + + static constexpr auto KPerXdlops = smfmac.GetKPerXdlops(); + static constexpr auto K1PerXdlops = smfmac.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() + { + return make_tuple( + Number{}, I1, Number{}, I1); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 70fbcec10fa8603b5f23a90756280ea2776e2914..9a9ebf55951ad7f4f64b05de680e53ef8521ffdc 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -11,12 +11,17 @@ namespace ck { enum struct WmmaInstr { + // gfx11 wmma_f32_16x16x16_f16 = 0, wmma_f32_16x16x16_bf16, wmma_f16_16x16x16_f16, wmma_bf16_16x16x16_bf16, wmma_i32_16x16x16_iu8, - wmma_i32_16x16x16_iu4 + wmma_i32_16x16x16_iu4, + // gfx12 + wmma_f32_16x16x16_f16_gfx12, + wmma_f32_16x16x16_bf16_gfx12, + wmma_i32_16x16x16_iu8_gfx12, }; /* @@ -95,7 +100,7 @@ struct wmma_type{}; - // * Fixed in Navi3x, Will be wave mode dependent on Navi4x + // * Fixed on gfx11, Will be wave mode dependent for future architectures static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; // * num_acc_vgprs_per_wave alone M direction @@ -279,6 +284,122 @@ struct wmma_type +struct wmma_type> +{ + // Absolute fixing property + // * Data Pixel + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + // static constexpr index_t acc_data_size = 4; + // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // * Fixed in Navi3x, Will be wave mode dependent on Navi4x + // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4; + // * num_acc_vgprs_per_wave alone M direction + // * num_subgroups alone M direction + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_f16_w32_gfx12::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_bf16_w32_gfx12::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_i32_16x16x16_iu8_w32_gfx12::Run( + a, b, reg_c); + } + } +}; + template static constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; +#else return WmmaInstr::wmma_f32_16x16x16_f16; +#endif } template <> static constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; +#else return WmmaInstr::wmma_f32_16x16x16_bf16; +#endif } template <> @@ -320,8 +449,13 @@ struct WmmaSelector template <> static constexpr auto GetWmma() { +#ifdef __gfx12__ + return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; +#else return WmmaInstr::wmma_i32_16x16x16_iu8; +#endif } + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> static constexpr auto GetWmma() @@ -502,6 +636,9 @@ struct WmmaGemm __device__ static auto GetSubGroupId() { + static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups == + wmma_instr.wave_size, + ""); return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups; } @@ -516,12 +653,20 @@ struct WmmaGemm __host__ __device__ static auto CalculateAThreadOriginDataIndex() { +#ifdef __gfx12__ + return GetLaneIdUnderSubGroup(); +#else return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); +#endif } __host__ __device__ static auto CalculateBThreadOriginDataIndex() { +#ifdef __gfx12__ + return GetLaneIdUnderSubGroup(); +#else return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); +#endif } __device__ static CIndex GetBeginOfThreadBlk() diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bc290d56413ffb5c7214511c95ec97b6a8b0c617 --- /dev/null +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -0,0 +1,640 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +/** + * @brief Transform conv bwd weight to gemm v2 + * + * This version does following things: + * 1. Merge KBatch with K0 to align descriptor with universal gemm + * 2. Merge Batch with M and N dimension. It allows to increase compute in + * case of small M and N. It also allows to vector load and store in case of + * K = 1, C = 1 and NHWGC layout. + */ +template +struct TransformConvBwdWeightToGemmV2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[4]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumGroupsToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t HiStride = input_strides[3]; + const index_t WiStride = input_strides[4]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumGroupsToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, HiStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[4]; + const auto BatchStride = weights_strides[0]; + // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placehorder + // for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumGroupsToMerge, K, Y * X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K), + make_pass_through_transform(Y * X), + make_pad_transform(1, 0, NumGroupsToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(Y * X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(Y * X, NumGroupsToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumGroupsToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumGroupsToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[5]; + const auto BatchStride = weights_strides[0]; + // Add NumGroupsToMerge for Batch+M dimension and, 1 for placehord for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumGroupsToMerge, K, Z * Y * X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K), + make_pass_through_transform(Z * Y * X), + make_pad_transform(1, 0, NumGroupsToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(Z * Y * X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(Z * Y * X, NumGroupsToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * X * Y * NumGroupsToMerge; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5>{}, + Sequence<6>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, NumGroupsToMerge, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * Z * X * Y * NumGroupsToMerge; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{}, + Sequence<8>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumGroupsToMerge, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7, 8>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } // function end +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index e2f75142d4187c9581a8816f56ccb2f3608d64ab..8dd657301526f3fc5608a8005a12dee77c0c9ebe 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,11 +14,93 @@ namespace ck { namespace tensor_operation { -template +// function to be used on device, emulates std::accumulate +template +__host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T init) +{ + for(ForwardIterator x = first; x != first + count; x++) + { + init *= *x; + } + return init; +} + +template struct TransformConvFwdToGemm { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static long_index_t + calculate_element_space_size_impl(const std::array& lengths, + const std::array& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } + + return acc; + } + + template + static index_t GetSplitedNSize(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); + const long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const index_t N = a_g_n_c_wis_lengths[I1]; + + if(element_space_size > TwoGB) + { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. + for(index_t least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Not possible to support even after split N. + // Too large tensor. + return N; + } + } + else + { + // Split N is not needed. + return N; + } + } // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as // properties @@ -38,7 +120,1076 @@ struct TransformConvFwdToGemm const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, - const std::array& input_right_pads) + const std::array& input_right_pads, + const index_t N) + { + const index_t C = a_g_n_c_wis_lengths[I2]; + + const index_t Wi = a_g_n_c_wis_lengths[I3]; + + const index_t Wo = c_g_n_k_wos_lengths[I3]; + + const index_t ConvStrideW = conv_filter_strides[I0]; + + const index_t GStride = a_g_n_c_wis_strides[I0]; + const index_t NStride = a_g_n_c_wis_strides[I1]; + const auto CStride = a_g_n_c_wis_strides[I2]; + const index_t WiStride = a_g_n_c_wis_strides[I3]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NHoWo, C), + make_tuple(WiStride, CStride)); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + const index_t ConvDilationW = conv_filter_dilations[0]; + + const index_t InLeftPadW = input_left_pads[0]; + + const index_t InRightPadW = input_right_pads[0]; + if constexpr(NumGroupsToMerge == 1) + { + + const auto in_n_wi_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Wi), make_tuple(NStride, WiStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(Number<3>{})), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, NumGroupsToMerge), make_tuple(NStride, WiStride, GStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), + make_pass_through_transform(Number<3>{})), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Wi, NumGroupsToMerge, C), + make_tuple(NStride, WiStride, GStride, CStride)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else + { + const index_t X = b_g_k_c_xs_lengths[3]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Wi, NumGroupsToMerge, C), + make_tuple(NStride, WiStride, GStride, CStride)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(X, C))), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t N) + + { + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Hi = a_g_n_c_wis_lengths[3]; + const index_t Wi = a_g_n_c_wis_lengths[4]; + + const index_t Ho = c_g_n_k_wos_lengths[3]; + const index_t Wo = c_g_n_k_wos_lengths[4]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t GStride = a_g_n_c_wis_strides[I0]; + const index_t NStride = a_g_n_c_wis_strides[I1]; + const index_t CStride = a_g_n_c_wis_strides[I2]; + const index_t HiStride = a_g_n_c_wis_strides[I3]; + const index_t WiStride = a_g_n_c_wis_strides[I4]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NHoWo, C), + make_tuple(WiStride, CStride)); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi), make_tuple(NStride, HiStride, WiStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = + make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, NumGroupsToMerge), + make_tuple(NStride, HiStride, WiStride, GStride)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, HiStride, WiStride, GStride, CStride)); + + const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else + { + const index_t Y = b_g_k_c_xs_lengths[3]; + const index_t X = b_g_k_c_xs_lengths[4]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, HiStride, WiStride, GStride, CStride)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5>{}, + Sequence<6>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Y, X, C))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + static auto + MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& /* b_g_k_c_xs_strides */, + const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides*/, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t N) + + { + const index_t C = a_g_n_c_wis_lengths[2]; + + const index_t Di = a_g_n_c_wis_lengths[3]; + const index_t Hi = a_g_n_c_wis_lengths[4]; + const index_t Wi = a_g_n_c_wis_lengths[5]; + + const index_t Do = c_g_n_k_wos_lengths[3]; + const index_t Ho = c_g_n_k_wos_lengths[4]; + const index_t Wo = c_g_n_k_wos_lengths[5]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t GStride = a_g_n_c_wis_strides[I0]; + const index_t NStride = a_g_n_c_wis_strides[I1]; + const index_t CStride = a_g_n_c_wis_strides[I2]; + const index_t DiStride = a_g_n_c_wis_strides[I3]; + const index_t HiStride = a_g_n_c_wis_strides[I4]; + const index_t WiStride = a_g_n_c_wis_strides[I5]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + const index_t NDoHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), + make_tuple(WiStride, CStride)); + } + else + { + const auto in_gemmm_groups_gemmk_desc = + make_naive_tensor_descriptor(make_tuple(NDoHoWo, NumGroupsToMerge, C), + make_tuple(WiStride, GStride, CStride)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi), make_tuple(NStride, DiStride, HiStride, WiStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge), + make_tuple(NStride, DiStride, HiStride, WiStride, GStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Number<3>{}, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Number<3>{}, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(Number<3>{}, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else + { + const index_t Z = b_g_k_c_xs_lengths[3]; + const index_t Y = b_g_k_c_xs_lengths[4]; + const index_t X = b_g_k_c_xs_lengths[5]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{}, + Sequence<8>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), + make_merge_transform(make_tuple(Z, Y, X, C))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + static auto MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const index_t K = b_g_k_c_xs_lengths[1]; + const index_t C = b_g_k_c_xs_lengths[2]; + + const index_t YX = ck::accumulate_n( + b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const index_t GStride = b_g_k_c_xs_strides[I0]; + const index_t KStride = b_g_k_c_xs_strides[I1]; + const index_t CStride = b_g_k_c_xs_strides[I2]; + + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + using FilterSizeNumType = + std::conditional_t, + std::conditional_t, Number<27>>>; + + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K, FilterSizeNumType{})); + } + else + { + + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K, NumGroupsToMerge, FilterSizeNumType{}), + make_tuple(KStride, GStride, CStride)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)), + make_pass_through_transform(FilterSizeNumType{})), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); + } + else + { + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K, NumGroupsToMerge, YX * C), make_tuple(KStride, GStride, CStride)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)), + make_pass_through_transform(YX * C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + } + + template < + typename BLayout, + typename std::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + static auto MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const index_t K = b_g_k_c_xs_lengths[1]; + const index_t C = b_g_k_c_xs_lengths[2]; + + const index_t YX = ck::accumulate_n( + b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const index_t KStride = b_g_k_c_xs_strides[1]; + const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial]; + const auto CStride = I1; + + const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( + make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride)); + + const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor( + wei_k_yx_c_desc, + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return wei_gemmn_gemmk_desc; + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + static auto + MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& /* c_g_n_k_wos_strides */, + const index_t N) + { + const index_t K = c_g_n_k_wos_lengths[2]; + + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K)); + + return out_gemmm_gemmn_desc; + } + + template < + typename CLayout, + typename std::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides, + const index_t N) + { + const index_t K = c_g_n_k_wos_lengths[2]; + + const index_t KStride = I1; + const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2]; + const index_t GStride = c_g_n_k_wos_strides[0]; + + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NHoWo, K), + make_tuple(WoStride, KStride)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, NumGroupsToMerge, K, 1), + make_tuple(WoStride, GStride, KStride, GStride)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_pass_through_transform(NHoWo), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K, NumGroupsToMerge))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + // for output bias + template , + bool>::type = false> + static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides, + const index_t N) + { + const index_t K = c_g_n_k_wos_lengths[2]; + const index_t KStride = c_g_n_k_wos_strides[2]; + + const index_t NHoWo = + N * ck::accumulate_n( + c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + + const auto out_gemmm_gemmn_desc = + make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride)); + + return out_gemmm_gemmn_desc; + } + + // Overloaded functions for hipRTC purposes + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ static auto + MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */, + const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads) { const index_t N = a_g_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2]; @@ -141,17 +1292,17 @@ struct TransformConvFwdToGemm is_same_v || is_same_v), bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + __host__ __device__ static auto + MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */, + const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads) { const index_t N = a_g_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2]; @@ -271,16 +1422,16 @@ struct TransformConvFwdToGemm is_same_v), bool>::type = false> static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + MakeADescriptor_M_K(const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */, + const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads) { const index_t N = a_g_n_c_wis_lengths[1]; const index_t C = a_g_n_c_wis_lengths[2]; @@ -421,15 +1572,15 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto - MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */) + __host__ __device__ static auto + MakeBDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& /* b_g_k_c_xs_strides */) { const index_t K = b_g_k_c_xs_lengths[1]; const index_t C = b_g_k_c_xs_lengths[2]; - const index_t YX = ck::accumulate_n( - b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + const index_t YX = + mult_accumulate_n(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1); const auto wei_gemmn_gemmk_desc = make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); @@ -446,14 +1597,15 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto MakeBDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + __host__ __device__ static auto + MakeBDescriptor_N_K(const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides) { const index_t K = b_g_k_c_xs_lengths[1]; const index_t C = b_g_k_c_xs_lengths[2]; - const index_t YX = ck::accumulate_n( - b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + const index_t YX = + mult_accumulate_n(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1); const index_t KStride = b_g_k_c_xs_strides[1]; const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial]; @@ -476,16 +1628,15 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto - MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */) + __host__ __device__ static auto + MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& /* c_g_n_k_wos_strides */) { const index_t N = c_g_n_k_wos_lengths[1]; const index_t K = c_g_n_k_wos_lengths[2]; const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K)); @@ -501,8 +1652,9 @@ struct TransformConvFwdToGemm is_same_v || is_same_v, bool>::type = false> - static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides) + __host__ __device__ static auto + MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& c_g_n_k_wos_strides) { const index_t N = c_g_n_k_wos_lengths[1]; const index_t K = c_g_n_k_wos_lengths[2]; @@ -511,8 +1663,7 @@ struct TransformConvFwdToGemm const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2]; const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride)); @@ -524,16 +1675,16 @@ struct TransformConvFwdToGemm template , bool>::type = false> - static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& c_g_n_k_wos_strides) + __host__ __device__ static auto + MakeCDescriptor_M_N(const ck::Array& c_g_n_k_wos_lengths, + const ck::Array& c_g_n_k_wos_strides) { const index_t N = c_g_n_k_wos_lengths[1]; const index_t K = c_g_n_k_wos_lengths[2]; const index_t KStride = c_g_n_k_wos_strides[2]; const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); + N * mult_accumulate_n(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1); const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride)); @@ -542,5 +1693,38 @@ struct TransformConvFwdToGemm } }; +// wrapper class to call member functions on TransformConvToGemm struct at runtime +// TODO: figure out aq way to properly pass in layout as an argument +struct TransformConv +{ + TransformConv() {} + + template + auto + transform_func(ck::Array out_lengths, + ck::Array out_strides, + TransformConvFwdToGemm conv_fwd_to_gemm) + { + if(NDimSpatial == 2) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(out_lengths, + out_strides); + } + else if(NDimSpatial == 3) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(out_lengths, + out_strides); + } + else if(NDimSpatial == 1) + { + return conv_fwd_to_gemm.template MakeCDescriptor_M_N( + out_lengths, out_strides); + } + } +}; + } // namespace tensor_operation } // namespace ck diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 678c55b95f4e64f091752929dffb8221ba20197b..ab22134fc690011f9aa9f1b9222e434d22a04c5f 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" @@ -297,6 +297,17 @@ enum struct AmdBufferCoherenceEnum GLC = 1, SLC = 2, GLC_SLC = 3, + // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 + // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system + // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse + WAVE_NT0 = 0, + WAVE_NT1 = 2, + GROUP_NT0 = 1, + GROUP_NT1 = 3, + DEVICE_NT0 = 8, + DEVICE_NT1 = 10, + SYSTEM_NT0 = 9, + SYSTEM_NT1 = 11, }; template @@ -980,7 +991,8 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "v"(global_offset_bytes), - "s"(src_resource)); + "s"(src_resource) + : "memory"); #else // LDS pointer must be attributed with the LDS address space. __attribute__((address_space(3))) uint32_t* lds_ptr = diff --git a/include/ck/utility/amd_smfmac.hpp b/include/ck/utility/amd_smfmac.hpp new file mode 100644 index 0000000000000000000000000000000000000000..abb8d9f5ef8cc689c83b6c3c0e064a3b3bbcf101 --- /dev/null +++ b/include/ck/utility/amd_smfmac.hpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#pragma once + +namespace ck { + +template +struct intrin_smfmac_f32_16x16x32f16; + +template <> +struct intrin_smfmac_f32_16x16x32f16<16, 16> +{ + template + __device__ static void + Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +template +struct intrin_smfmac_f32_16x16x32bf16; + +template <> +struct intrin_smfmac_f32_16x16x32bf16<16, 16> +{ + template + __device__ static void + Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +template +struct intrin_smfmac_f32_32x32x16f16; + +template <> +struct intrin_smfmac_f32_32x32x16f16<32, 32> +{ + template + __device__ static void + Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +template +struct intrin_smfmac_f32_32x32x16bf16; + +template <> +struct intrin_smfmac_f32_32x32x16bf16<32, 32> +{ + template + __device__ static void + Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +} // namespace ck diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index 741b2975af6c5bf99346b1460018eac6fa33b21b..d6e1eab314e30184c669abe88f5a4cf7f5ea90c4 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -95,11 +95,33 @@ using get_carrier_t = typename get_carrier::type; } // namespace detail +__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + __device__ inline int32_t amd_wave_read_first_lane(int32_t value) { return __builtin_amdgcn_readfirstlane(value); } +__device__ inline int64_t amd_wave_read_first_lane(int64_t value) +{ + constexpr unsigned object_size = sizeof(int64_t); + constexpr unsigned second_part_offset = object_size / 2; + auto* const from_obj = reinterpret_cast(&value); + alignas(int64_t) std::byte to_obj[object_size]; + + using Sgpr = uint32_t; + + *reinterpret_cast(to_obj) = + amd_wave_read_first_lane(*reinterpret_cast(from_obj)); + *reinterpret_cast(to_obj + second_part_offset) = + amd_wave_read_first_lane(*reinterpret_cast(from_obj + second_part_offset)); + + return *reinterpret_cast(to_obj); +} + template < typename Object, typename = std::enable_if_t && std::is_trivially_copyable_v>> diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index 1bb0140f3e2008c23cf589ff35e6090309e03802..322a0f94bb86552116449d155c3147d000eae3d4 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> } }; +// gfx12 +/********************************WAVE32 MODE***********************************************/ + +#if defined(__gfx1200__) || defined(__gfx1201__) +#define __gfx12__ +#endif + +// src: fp16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f16_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { + // * Inline assembly need to elimate the duplicated data load, compiler won't help you + // delete them. + // amd_assembly_wmma_f32_16x16x16_f16_w32( + // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: iu8, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12; + +template +struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + } // namespace ck #endif diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 0ee52b9570f27b1a00916870e68f878009bcfc10..d8ccb2ea7620c02098d25332b53d8356af06dec8 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -4,7 +4,7 @@ #pragma once namespace ck { -// Define the common macro for MI300 models +// Define the common macro for gfx94x models #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define __gfx94__ #endif diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index f63ce5e5a07a796888cb60ae8da0c855df75e7ff..5366c56a9dfa7275ecca75d41daaf1a5cba6333d 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -36,6 +36,8 @@ struct Array return *this; } + __host__ __device__ constexpr const TData* begin() const { return &mData[0]; } + __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; } }; // empty Array diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 93a1edefb6def6ea4254df6331a2cbac51718886..4df14c6211bc644faa9997628a6caeb8e96bad4b 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -203,7 +203,7 @@ struct vector_type } }; -int static err = 0; +__device__ int static err = 0; template struct vector_type { diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6455402dcb331d91240ebe09d4d553f4d355f96e --- /dev/null +++ b/include/ck/utility/env.hpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace ck { +namespace internal { +template +struct ParseEnvVal +{ +}; + +template <> +struct ParseEnvVal +{ + static bool parse_env_var_value(const char* vp) + { + std::string value_env_str{vp}; + + for(auto& c : value_env_str) + { + if(std::isalpha(c) != 0) + { + c = std::tolower(static_cast(c)); + } + } + + if(value_env_str == "disable" || value_env_str == "disabled" || value_env_str == "0" || + value_env_str == "no" || value_env_str == "off" || value_env_str == "false") + { + return false; + } + else if(value_env_str == "enable" || value_env_str == "enabled" || value_env_str == "1" || + value_env_str == "yes" || value_env_str == "on" || value_env_str == "true") + { + return true; + } + else + { + throw std::runtime_error("Invalid value for env variable"); + } + + return false; // shouldn't reach here + } +}; + +// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default). +// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string). +template <> +struct ParseEnvVal +{ + static uint64_t parse_env_var_value(const char* vp) { return std::strtoull(vp, nullptr, 0); } +}; + +template <> +struct ParseEnvVal +{ + static std::string parse_env_var_value(const char* vp) { return std::string{vp}; } +}; + +template +struct EnvVar +{ + private: + T value{}; + bool is_unset = true; + + public: + const T& GetValue() const { return value; } + + bool IsUnset() const { return is_unset; } + + void Unset() { is_unset = true; } + + void UpdateValue(const T& val) + { + is_unset = false; + value = val; + } + + explicit EnvVar(const char* const name, const T& def_val) + { + // NOLINTNEXTLINE (concurrency-mt-unsafe) + const char* vp = std::getenv(name); + if(vp != nullptr) // a value was provided + { + is_unset = false; + value = ParseEnvVal::parse_env_var_value(vp); + } + else // no value provided, use default value + { + value = def_val; + } + } +}; +} // end namespace internal + +// static inside function hides the variable and provides +// thread-safety/locking +// Used in global namespace +#define CK_DECLARE_ENV_VAR(name, type, default_val) \ + namespace ck::env { \ + struct name \ + { \ + static_assert(std::is_same_v, \ + "CK_DECLARE_ENV* must be used in the global namespace"); \ + using value_type = type; \ + static ck::internal::EnvVar& Ref() \ + { \ + static ck::internal::EnvVar var{#name, default_val}; \ + return var; \ + } \ + }; \ + } + +#define CK_DECLARE_ENV_VAR_BOOL(name) CK_DECLARE_ENV_VAR(name, bool, false) + +#define CK_DECLARE_ENV_VAR_UINT64(name) CK_DECLARE_ENV_VAR(name, uint64_t, 0) + +#define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "") + +#define CK_ENV(name) \ + ck::env::name {} + +template +inline const std::string& EnvGetString(EnvVar) +{ + static_assert(std::is_same_v); + return EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsEnabled(EnvVar) +{ + static_assert(std::is_same_v); + return !EnvVar::Ref().IsUnset() && EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsDisabled(EnvVar) +{ + static_assert(std::is_same_v); + return !EnvVar::Ref().IsUnset() && !EnvVar::Ref().GetValue(); +} + +template +inline uint64_t EnvValue(EnvVar) +{ + static_assert(std::is_same_v); + return EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsUnset(EnvVar) +{ + return EnvVar::Ref().IsUnset(); +} + +template +void EnvUnset(EnvVar) +{ + EnvVar::Ref().Unset(); +} + +/// updates the cached value of an environment variable +template +void UpdateEnvVar(EnvVar, const ValueType& val) +{ + static_assert(std::is_same_v); + EnvVar::Ref().UpdateValue(val); +} + +template +void UpdateEnvVar(EnvVar, const std::string_view& val) +{ + EnvVar::Ref().UpdateValue( + ck::internal::ParseEnvVal::parse_env_var_value(val.data())); +} + +} // namespace ck diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 2b921cdc7cc2c0d34158de8ccc7384898cba054d..d961cdb1981633b79efb77e37cd6c080fefea3b5 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -839,7 +839,7 @@ inline __device__ T rcp(T x) template inline __device__ T exp(T x) { - return ck::type_convert(__expf(ck::type_convert(x))); + return ck::type_convert(__ocml_exp_f32(ck::type_convert(x))); }; template <> @@ -851,7 +851,7 @@ inline __device__ half_t exp(half_t x) template <> inline __device__ float exp(float x) { - return __expf(x); + return __ocml_exp_f32(x); }; template <> diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp index 4fe5e39504da2f032bad12083e396a506c5c921d..d6b6eac26c06f74c661fa09a006b0a73a4b70214 100644 --- a/include/ck/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -10,12 +10,20 @@ namespace ck { __device__ void block_sync_lds() { #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM +#ifdef __gfx12__ + asm volatile("\ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else // asm volatile("\ // s_waitcnt lgkmcnt(0) \n \ // s_barrier \ // " ::); __builtin_amdgcn_s_waitcnt(0xc07f); __builtin_amdgcn_s_barrier(); +#endif #else __syncthreads(); #endif @@ -23,11 +31,20 @@ __device__ void block_sync_lds() __device__ void block_sync_lds_direct_load() { +#ifdef __gfx12__ + asm volatile("\ + s_wait_vmcnt 0x0 \n \ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("\ s_waitcnt vmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif } __device__ void s_nop() diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index be74b1fdc10ee2dc0468a85106516223ad815ba1..382b9c555152352e37f783452b4ec8d8e3009840 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -8,7 +8,7 @@ #include "ck/utility/random_gen.hpp" namespace ck { -// Define the common macro for MI300 models +// Define the common macro for gfx94x models #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define __gfx94__ #endif diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index bb19c9154bfa75fe9ce976c3c2f1838a8c5b0255..4cddf6faa94bc53700f81fab751816a860163531 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/utility.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" @@ -26,6 +27,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" @@ -47,10 +49,12 @@ #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 53f42a74217647c61ad2756081dd269a60d13cf4..7f488d1b71e7d42b3e39d40fb9ea2991cfd53974 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -26,237 +26,346 @@ struct __attribute__((packed)) buffer_resource CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; - return __builtin_bit_cast(int32x4_t, res); + int32x4_t r = __builtin_bit_cast(int32x4_t, res); + r.x = __builtin_amdgcn_readfirstlane(r.x); + r.y = __builtin_amdgcn_readfirstlane(r.y); + r.z = __builtin_amdgcn_readfirstlane(r.z); + r.w = __builtin_amdgcn_readfirstlane(r.w); + return r; } +namespace impl { +// below type indicate the data type used for buffer load inline asm +// clang-format off +template struct buffer_load_trait; + +template struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; }; +template struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; }; +template struct buffer_load_trait<4 , T> { using payload_t = float; }; +template struct buffer_load_trait<2 , T> { using payload_t = float; }; +template struct buffer_load_trait<1 , T> { using payload_t = float; }; + +#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA +template<> struct buffer_load_trait<16, thread_buffer> { using payload_t = bf16x8_t; }; +template<> struct buffer_load_trait<8 , thread_buffer> { using payload_t = bf16x4_t; }; +template<> struct buffer_load_trait<4 , thread_buffer> { using payload_t = bf16x2_t; }; +#endif +// clang-format on +} // namespace impl + // TODO: glc/slc/... -template +template struct buffer_load; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type // (exp_vector_type(xxx)) -template <> -struct buffer_load<16> +template +struct buffer_load<16, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 16); - using mbuf_t = fp32x4_t; - asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<8> +template +struct buffer_load<8, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 8); - using mbuf_t = fp32x2_t; - asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<4> +template +struct buffer_load<4, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); - using mbuf_t = float; - asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<2> +template +struct buffer_load<2, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually - using mbuf_t = float; - asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<1> +template +struct buffer_load<1, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); - using mbuf_t = float; - asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template +template struct buffer_load_if; -template <> -struct buffer_load_if<16> +template +struct buffer_load_if<16, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<8> +template +struct buffer_load_if<8, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x2_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<4> +template +struct buffer_load_if<4, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<2> +template +struct buffer_load_if<2, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<1> +template +struct buffer_load_if<1, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; #pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" @@ -270,17 +379,16 @@ struct buffer_store<16> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 16); using mbuf_t = fp32x4_t; - asm volatile( - "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -291,17 +399,16 @@ struct buffer_store<8> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 8); using mbuf_t = fp32x2_t; - asm volatile( - "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -312,17 +419,16 @@ struct buffer_store<4> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); using mbuf_t = float; - asm volatile( - "buffer_store_dword %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -333,17 +439,16 @@ struct buffer_store<2> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 2); using mbuf_t = short; - asm volatile( - "buffer_store_short %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -354,17 +459,16 @@ struct buffer_store<1> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); using mbuf_t = float; - asm volatile( - "buffer_store_byte %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -378,21 +482,20 @@ struct buffer_store_if<16> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 16); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = fp32x4_t; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -407,7 +510,7 @@ struct buffer_store_if<8> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { @@ -415,14 +518,13 @@ struct buffer_store_if<8> auto save_exec = __builtin_amdgcn_read_exec(); // TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch using mbuf_t = ext_vector_t; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -437,21 +539,20 @@ struct buffer_store_if<4> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dword %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -466,21 +567,20 @@ struct buffer_store_if<2> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 2); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = short; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_short %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_short %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -495,21 +595,20 @@ struct buffer_store_if<1> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_byte %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -533,8 +632,9 @@ namespace impl{ template CK_TILE_DEVICE void insert_dummy_dep_per_dword(array& b) { - static_for<0, b.size(), 1>{}([&](auto i){ - asm volatile(" " : : "v"(b.get(i)) : "memory"); + constexpr auto kSize = remove_cvref_t::size(); + static_for<0, kSize, 1>{}([&](auto i){ + asm volatile(" " : : "v"(b.get(number{})) : "memory"); }); } #if 1 @@ -764,6 +864,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); +// buffer store ui16 +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, int32x4_t rsrc, @@ -854,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); -CK_TILE_DEVICE void async_buffer_load_dword(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0) +template +CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t /*soffset*/, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0, + bool_constant = {}) { - asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dword %1, %2, 0 offen offset:%3 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "n"(ioffset) + : "memory"); + else + asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "n"(ioffset) + : "memory"); } CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) @@ -1176,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe template + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, @@ -1190,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, using type = thread_buffer; if constexpr(oob_conditional_check) { - buffer_load_if{}( - dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + buffer_load_if{}(dst, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + 0, + flag, + bool_constant{}); } else { - buffer_load{}( - dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + buffer_load{}(dst, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + 0, + flag, + bool_constant{}); } } template + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool pre_nop = false> CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, - index_t src_immediate_addr_offset = 0) + index_t src_immediate_addr_offset = 0, + bool_constant = {}) { static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); - async_buffer_load_dword(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset); + async_buffer_load_dword_v(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + 0, + bool_constant{}); } template src_thread_d (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); if constexpr(std::is_same::value) // fp32 @@ -1473,6 +1623,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d static_cast(coherence)); } } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_ui16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_ui16x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_ui16x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(uint16_t), + static_cast(coherence)); + } + } else { using r_t = thread_buffer; @@ -1590,7 +1783,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer& src_th { if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1816,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, template + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, const T* p_src_wave, index_t src_thread_element_offset, index_t src_element_space_size, - index_t is_valid_element = 0) + index_t is_valid_element = 0, + bool_constant = {}) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - amd_buffer_load_raw_impl( - dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element); + amd_buffer_load_raw_impl( + dst, + src_wave_buffer_resource, + src_thread_addr_offset, + 0, + is_valid_element, + bool_constant{}); +} + +// This version support buffer resource as input arg +template +CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, + const int32x4_t src_wave_buffer_resource, + index_t src_thread_element_offset, + index_t is_valid_element = 0, + bool_constant = {}) +{ + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_buffer_load_raw_impl( + dst, + src_wave_buffer_resource, + src_thread_addr_offset, + 0, + is_valid_element, + bool_constant{}); } // unfortunately async copy can not make sure invalid data is zero inside LDS @@ -1838,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, // buffer_load OOB still working. template -CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, - const T* p_src_wave, - index_t src_thread_element_offset, - index_t src_element_space_size) + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool pre_nop = false> +CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size, + bool_constant = {}) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); @@ -1850,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); amd_async_buffer_load_impl( - smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0); + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{}); +} + +// This version support buffer resource as input arg +template +CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, + const int32x4_t src_wave_buffer_resource, + index_t src_thread_element_offset, + bool_constant = {}) +{ + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_async_buffer_load_impl( + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{}); } // buffer_store requires: @@ -2016,7 +2257,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "v"(global_offset_bytes), - "s"(src_resource)); + "s"(src_resource) + : "memory"); #else // LDS pointer must be attributed with the LDS address space. __attribute__((address_space(3))) uint32_t* lds_ptr = diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 888f0e728ffee4e7c89b2fa689b7ac089eac521f..65a3a4e2fff317acc11d0ccacbda5a4d580826e6 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -61,10 +61,13 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } CK_TILE_DEVICE void block_sync_lds() { #if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM - asm volatile("\ - s_waitcnt lgkmcnt(0) \n \ - s_barrier \ - " ::); + // asm volatile("\ + // s_waitcnt lgkmcnt(0) \n \ + // s_barrier \ + // " ::); + + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); #else __syncthreads(); #endif @@ -79,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load() " ::); } -CK_TILE_DEVICE void s_nop() +CK_TILE_DEVICE void s_nop(index_t cnt = 0) { #if 1 - asm volatile("\ - s_nop 0 \n \ - " ::); + asm volatile("s_nop %0" : : "n"(cnt) :); #else - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(cnt); #endif } diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6212db9169ecc8070cdfc1916487c212d1320bf7 --- /dev/null +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/numeric/type_convert.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" + +namespace ck_tile { + +CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b) +{ + return type_convert(type_convert(a) + type_convert(b)); +} + +CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b) +{ + bf16x2_t rtn; + rtn[0] = add_bf16_t(a[0], b[0]); + rtn[1] = add_bf16_t(a[1], b[1]); + return rtn; +} + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_add explicit for +// each datatype. +template +CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x); + +template <> +CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) +{ + union U32BF162_ADDR + { + uint32_t* u32_a; + bf16x2_t* bf162_a; + }; + + union U32BF162 + { + uint32_t u32; + bf16x2_t bf162; + }; + + U32BF162_ADDR dword_addr; + U32BF162 cur_v; + U32BF162 new_; + uint32_t old_v, new_v; + dword_addr.bf162_a = p_dst; + cur_v.u32 = *dword_addr.u32_a; + + do + { + old_v = cur_v.u32; + new_.bf162 = add_bf16x2_t(cur_v.bf162, x); + new_v = new_.u32; + cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); + } while(cur_v.u32 != old_v); +} + +template +CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer& x) +{ + static_assert((std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 2 || N == 4)), + "wrong! not implemented"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicAdd(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicAdd(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + return atomicAdd(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicAdd(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicAdd(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicAdd(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 2) + { + atomic_add(c_style_pointer_cast(p_dst), bit_cast(x)); + } + else if constexpr(N == 4) + { + atomic_add(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomic_add(c_style_pointer_cast(p_dst) + 1, + x.template get_as()[I1]); + } + } +} + +template +CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer& x) +{ + static_assert((std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1)) || + (std::is_same::value && (N == 1 || N == 2)) || + (std::is_same::value && (N == 1)), + "wrong! not implemented"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + else if constexpr(N == 2) + { + atomicMax(c_style_pointer_cast(p_dst), x.template get_as()[I0]); + atomicMax(c_style_pointer_cast(p_dst) + 1, x.template get_as()[I1]); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + atomicMax(p_dst, bit_cast(x)); + } + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index d915df6e4c779b6fc60bebfd86f521673af9bb88..fa28aa2be95efa2e1fb9c0c6db9d17359ec09c76 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -3,6 +3,25 @@ #pragma once +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) +#define __gfx9__ +#endif +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define __gfx94__ +#endif +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) +#define __gfx103__ +#endif +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) +#define __gfx11__ +#endif +#if defined(__gfx1200__) || defined(__gfx1201__) +#define __gfx12__ +#endif + +#include "hip/hip_version.h" #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -109,15 +128,13 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #endif -#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__)) // for GPU code +#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code #define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 #else #define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 @@ -131,19 +148,26 @@ #define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #endif +#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE +#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091 +#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1 +#else +#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0 +#endif +#endif + #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff -#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \ + defined(__gfx9__) // for GPU code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 -#elif defined(__gfx1030__) // for GPU code +#elif defined(__gfx103__) // for GPU code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code +#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif @@ -154,3 +178,16 @@ #ifndef CK_TILE_USE_SUBDWORD_TILE_CAST #define CK_TILE_USE_SUBDWORD_TILE_CAST 0 #endif + +#ifndef CK_TILE_USE_PK_FP16_TILE_CAST +#define CK_TILE_USE_PK_FP16_TILE_CAST 0 +#endif + +// TODO: better solve this inside compiler +#ifndef CK_TILE_FMHA_FWD_FAST_EXP2 +#define CK_TILE_FMHA_FWD_FAST_EXP2 0 +#endif + +#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA +#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1 +#endif diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 071387163a95955a3e90fe8d70eddd38260db05d..4fdf8f9daedef7b8470c8534ced969a1b1c44dcb 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -331,7 +331,10 @@ bfloat16_t sqrt(bfloat16_t x) }; CK_TILE_DEVICE -bfloat16_t exp(bfloat16_t x) { return static_cast(__expf(static_cast(x))); }; +bfloat16_t exp(bfloat16_t x) +{ + return static_cast(__ocml_exp_f32(static_cast(x))); +}; CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x) { return static_cast(exp2f(static_cast(x))); }; diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index bad1009f2c0e3356c044e64e197c66cd7c8a65d2..b3b1a1f3fb2faf3f12d1fc5d8dc7a099bb4bf563 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -55,7 +55,7 @@ struct alignas(1) float8_e4m3_t { static constexpr int exponent = 4; static constexpr int mantissa = 3; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 1 << (exponent - 1); // NANOO #else static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE @@ -113,7 +113,7 @@ struct alignas(1) float8_e5m2_t { static constexpr int exponent = 5; static constexpr int mantissa = 2; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 1 << (exponent - 1); // NANOO #else static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE @@ -470,7 +470,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x) { constexpr int seed = 42; uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float max_fp8 = 240.0f; x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); union @@ -500,7 +500,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) { constexpr int seed = 42; uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) union { float fval; @@ -526,7 +526,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float max_fp8 = 240.0f; x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); union @@ -554,7 +554,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) } CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) union { float fval; @@ -598,7 +598,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant) CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float fval; uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); @@ -612,7 +612,7 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float fval; uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); @@ -656,7 +656,7 @@ struct numeric_traits { static constexpr int exp = 4; static constexpr int mant = 3; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 8; #else static constexpr int bias = 7; @@ -668,7 +668,7 @@ struct numeric_traits { static constexpr int exp = 5; static constexpr int mant = 2; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 16; #else static constexpr int bias = 15; // IEEE @@ -835,7 +835,7 @@ CK_TILE_DEVICE fp8_t sqrt(fp8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; CK_TILE_DEVICE -fp8_t exp(fp8_t x) { return static_cast(__expf(static_cast(x))); }; +fp8_t exp(fp8_t x) { return static_cast(__ocml_exp_f32(static_cast(x))); }; CK_TILE_DEVICE fp8_t exp2(fp8_t x) { return static_cast(exp2f(static_cast(x))); }; @@ -860,7 +860,7 @@ CK_TILE_DEVICE bf8_t sqrt(bf8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; CK_TILE_DEVICE -bf8_t exp(bf8_t x) { return static_cast(__expf(static_cast(x))); }; +bf8_t exp(bf8_t x) { return static_cast(__ocml_exp_f32(static_cast(x))); }; CK_TILE_DEVICE bf8_t exp2(bf8_t x) { return static_cast(exp2f(static_cast(x))); }; diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index c616b6939f5c6c2f672abee9b1df097595d39306..acb6eb6c3e016c7a16aef31282dea4979ee45f7b 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -129,8 +129,8 @@ constexpr double fp16_to_double_hip(const fp16_hip_t& x) CK_TILE_HOST_DEVICE constexpr fp16_hip_t float_to_fp16_hip(const float& x) { - return __float2half(x); - // return static_cast(x); + // return __float2half(x); + return static_cast(x); } CK_TILE_HOST_DEVICE @@ -374,7 +374,7 @@ half_t sqrt(half_t x) }; CK_TILE_DEVICE -half_t exp(half_t x) { return static_cast(__expf(static_cast(x))); }; +half_t exp(half_t x) { return static_cast(__ocml_exp_f32(static_cast(x))); }; CK_TILE_DEVICE half_t exp2(half_t x) { return static_cast(exp2f(static_cast(x))); }; diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp index 1166fcc3bcc340a3f8df901683ba2e4776c455fc..ff27108594990b8cc6a262c92c0c95df48bfd99f 100644 --- a/include/ck_tile/core/numeric/integral_constant.hpp +++ b/include/ck_tile/core/numeric/integral_constant.hpp @@ -56,7 +56,6 @@ CK_TILE_LEFT_UNARY_OP(+) CK_TILE_LEFT_UNARY_OP(-) CK_TILE_LEFT_UNARY_OP(~) CK_TILE_LEFT_UNARY_OP(!) -CK_TILE_LEFT_UNARY_OP(*) CK_TILE_BINARY_OP(+) CK_TILE_BINARY_OP(-) diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 72ec607b42e99036a002d0785b9523292d83091e..9970bb36930f3cb1033d00d83b058237a2d9d18b 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -519,7 +519,7 @@ CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; CK_TILE_DEVICE -float exp(float x) { return __expf(x); }; +float exp(float x) { return __ocml_exp_f32(x); }; CK_TILE_HOST float exp(float x) { return std::expf(x); } @@ -536,4 +536,15 @@ float log(float x) { return __logf(x); }; CK_TILE_HOST float log(float x) { return std::logf(x); }; +CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) +{ + // TODO: this is hacky, we use u16 + return __builtin_amdgcn_sad_u16(x, y, acc); +} + +CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) +{ + return (x > y ? (x - y) : (y - x)) + acc; +} + } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/null_type.hpp b/include/ck_tile/core/numeric/null_type.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8799c0560ea9598c0ba39bbe4886e9ab9f438630 --- /dev/null +++ b/include/ck_tile/core/numeric/null_type.hpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include + +namespace ck_tile { + +struct null_type +{ +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 85d9be1c948c69bdc89b35234ed9d3fba4ca57f4..c23c12f29574bb4b85e88de8882e0a152f237f3b 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -144,6 +144,15 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16))); using int8x32_t = int8_t __attribute((ext_vector_type(32))); using int8x64_t = int8_t __attribute((ext_vector_type(64))); +// ui8 +// using uint8_t +using uint8x2_t = uint8_t __attribute((ext_vector_type(2))); +using uint8x4_t = uint8_t __attribute((ext_vector_type(4))); +using uint8x8_t = uint8_t __attribute((ext_vector_type(8))); +using uint8x16_t = uint8_t __attribute((ext_vector_type(16))); +using uint8x32_t = uint8_t __attribute((ext_vector_type(32))); +using uint8x64_t = uint8_t __attribute((ext_vector_type(64))); + #if CK_TILE_USE_CUSTOM_DATA_TYPE // f8 // using fp8_t diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 96b38241c0c289db227cee52893b748f76c99323..ed705c91e72e5c506fddb3bcdad1360575760fd1 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" +#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" @@ -68,6 +69,8 @@ struct buffer_view invalid_element_value_ = T{0}; CK_TILE_HOST_DEVICE constexpr buffer_view() - : p_data_{}, buffer_size_{}, invalid_element_value_{} + : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) - : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size, T invalid_element_value) - : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + : p_data_{p_data}, + buffer_size_{buffer_size}, + cached_buf_res_{0}, + invalid_element_value_{invalid_element_value} + { + } + + // this is non constexpr intentially (will call some intrinsic internally) + // Must call for buffers that need *_raw load/store + CK_TILE_HOST_DEVICE void init_raw() { + cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type)); } CK_TILE_DEVICE static constexpr address_space_enum get_address_space() @@ -332,12 +346,15 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - get_raw(remove_cvref_t& dst, index_t i, bool is_valid_element) const + CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t& dst, + index_t i, + bool is_valid_element, + bool_constant = {}) const { constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -348,18 +365,21 @@ struct buffer_view, t_per_x, Coherence, oob_conditional_check>( - dst, p_data_, i, buffer_size_, is_valid_element); + amd_buffer_load_raw, t_per_x, Coherence, oob_conditional_check, pre_nop>( + dst, cached_buf_res_, i, is_valid_element, bool_constant{}); } // i is offset of T, not X. i should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - async_get(remove_cvref_t* smem, index_t i, bool /*is_valid_element*/) const + CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t* smem, + index_t i, + bool /*is_valid_element*/, + bool_constant = {}) const { // X is vector of T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -370,8 +390,8 @@ struct buffer_view, t_per_x, Coherence>( - smem, p_data_, i, buffer_size_); + amd_async_buffer_load_with_oob_raw, t_per_x, Coherence>( + smem, cached_buf_res_, i, bool_constant{}); } // i is offset of T, not X. i should be aligned to X @@ -507,10 +527,10 @@ struct buffer_view, t_per_x>( x, p_data_, i, is_valid_element, buffer_size_); } @@ -518,7 +538,7 @@ struct buffer_view(c_style_pointer_cast(&p_data_[i]), x); + atomic_add_g, t_per_x>(&p_data_[i], x); } } } @@ -547,16 +567,16 @@ struct buffer_view, t_per_x>( x, p_data_, i, is_valid_element, buffer_size_); } else if(is_valid_element) { - atomic_max(c_style_pointer_cast(&p_data_[i]), x); + atomic_max_g, t_per_x>(&p_data_[i], x); } } @@ -626,6 +646,8 @@ struct buffer_view + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE auto load_tile_raw(T& tile, const tile_window_with_static_distribution& tile_window, - bool_constant = {}) + bool_constant = {}, + bool_constant = {}) { - tile_window.load_raw(tile, bool_constant{}); + tile_window.load_raw(tile, bool_constant{}, bool_constant{}); } template + index_t NumCoord, + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, const tile_window_with_static_distribution& tile_window) + NumCoord>& tile_window, + bool_constant = {}, + bool_constant = {}) { - return tile_window.async_load(lds_tile); + return tile_window.async_load_raw( + lds_tile, bool_constant{}, bool_constant{}); } CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp index 89806203abb085cf7baa3f8505b49067446f7275..9707f2990a5bcaaaa1961b89c66ba86040ced516 100644 --- a/include/ck_tile/core/tensor/null_tile_window.hpp +++ b/include/ck_tile/core/tensor/null_tile_window.hpp @@ -35,6 +35,8 @@ struct null_tile_window CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; } + CK_TILE_DEVICE void init_raw() {} + WindowLengths window_lengths_; }; diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index c12ad883d93b681ba12b388754bc133498877f76..2efc65701395a0054a7abf27865c10b4d67e45f5 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index e37bd806de593a6436bcd430f1b198d895bec965..4655eec24156e3bc38e6a9c07b77d4a58ac94ab4 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -16,7 +16,9 @@ namespace ck_tile { -template +template struct tensor_view { using buffer_view = remove_reference_t; @@ -24,6 +26,7 @@ struct tensor_view using TensorDesc = remove_cvref_t; using TensorIndex = array; using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{})); + static constexpr auto DstInMemOp = DstInMemOp_; CK_TILE_HOST_DEVICE constexpr tensor_view() = default; @@ -33,6 +36,8 @@ struct tensor_view { } + CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); } + CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension() @@ -82,30 +87,34 @@ struct tensor_view // "coord" is coordinate of DataType, not X. "coord" should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE void - get_vectorized_elements_raw(remove_cvref_t& dst, - const TensorCoord& coord, - bool_constant = {}) const + CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t& dst, + const TensorCoord& coord, + bool_constant = {}, + bool_constant = {}) const { - return buf_.template get_raw( + return buf_.template get_raw( dst, coord.get_offset(), - coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); } template >::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t* smem, - const TensorCoord& coord) const + CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw( + remove_cvref_t* smem, const TensorCoord& coord, bool_constant = {}) const { - return buf_.template async_get(smem, coord.get_offset(), true /*not used*/); + return buf_.template async_get_raw( + smem, coord.get_offset(), true /*not used*/, bool_constant{}); } // X is vector of DataType. @@ -140,6 +149,23 @@ struct tensor_view x); } + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements( + const TensorCoord& coord, const X& x, bool_constant = {}) + { + buf_.template update( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } + CK_TILE_HOST_DEVICE void print() const { printf("tensor_view{"); @@ -178,6 +204,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, } template (p, desc.get_element_space_size()); - return tensor_view{buffer_view, desc}; + return tensor_view{buffer_view, desc}; } template >{ - old_tensor_view.buf_, new_desc}; + return tensor_view, + remove_cvref_t::DstInMemOp>{old_tensor_view.buf_, new_desc}; } template -CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number) +template +CK_TILE_DEVICE void +set_tile(DstrTensors& dstr_tensor, number, bool_constant = {}) { - constexpr index_t tensor_bytes = - DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType); - if constexpr(v == 0 && tensor_bytes % 4 == 0) + using elem_type = typename DstrTensors::DataType; + constexpr index_t elem_size = sizeof(elem_type); + + constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size; + + // # bytes per write = 4 + if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt) { +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE + auto& buffer = dstr_tensor.get_thread_buffer(); + + static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) { + if constexpr(elem_size == 1) + { + // # elements per write = 4 + constexpr auto values = ext_vector_t{0, 0, 0, 0}; + + buffer[i_write * 4 + 0] = values.x; + buffer[i_write * 4 + 1] = values.y; + buffer[i_write * 4 + 2] = values.z; + buffer[i_write * 4 + 3] = values.w; + } + else if constexpr(elem_size == 2) + { + // # elements per write = 2 + constexpr auto values = ext_vector_t{0, 0}; + + buffer[i_write * 2 + 0] = values.x; + buffer[i_write * 2 + 1] = values.y; + } + else if constexpr(elem_size == 4) + { + // # elements per write = 1 + constexpr elem_type value = 0; + + buffer[i_write] = value; + } + else + { + static_assert(false, "type not supported"); + } + }); +#else using dvec_t = array; auto& tensor = reinterpret_cast(dstr_tensor.get_thread_buffer()); for(auto i = 0; i < tensor.size(); i++) tensor.get(i) = v; +#endif } else { - tile_elementwise_inout( - [](auto& x) { x = type_convert(v); }, - dstr_tensor); + tile_elementwise_inout([](auto& x) { x = type_convert(v); }, + dstr_tensor); } } @@ -110,9 +150,9 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor) namespace impl { // TODO: this is ugly template -CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) +CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) // This API is designed to use the _pk_ serious of function constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); @@ -156,6 +196,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) #endif } +template +CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors) +{ +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) + // This API is designed to use the _pk_ serious of function + constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); + + constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size(); + static_assert(thread_buffer_size % 2 == 0); + constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2; + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); + + // TODO: this is rtz cvt, need be very careful + for(index_t i = 0; i < thread_buffer_size_pk; i++) + { + auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0], + in_dstr_tensors.get_thread_buffer()[2 * i + 1]); + + out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x; + out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y; + } + + return out_dstr_tensor; +#else + // fallback + return tile_elementwise_in(type_convert, + in_dstr_tensors); +#endif +} + #if CK_TILE_USE_SUBDWORD_TILE_CAST // this function assume either src or dst (or both) date type is under 1 dword // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) @@ -229,8 +300,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) float> && (SrcTensor::get_thread_buffer_size() % 4 == 0)) { - return impl::cast_tile_pk_fp8x4(src_tensor); + return impl::cast_tile_pk_fp8_fp32(src_tensor); } +#if CK_TILE_USE_PK_FP16_TILE_CAST + else if constexpr(std::is_same_v && + std::is_same_v && + (SrcTensor::get_thread_buffer_size() % 2 == 0)) + { + return impl::cast_tile_pk_fp16_fp32(src_tensor); + } +#endif #if CK_TILE_USE_SUBDWORD_TILE_CAST else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4) { diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 48ae6fec7316deaf9aec4661993344c58163273b..ebf7d4fb36a4c8ee2a85286f08d7605f93be1c07 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -355,9 +355,10 @@ struct tile_window_with_static_distribution return dst_tensor; } - template + template CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, - bool_constant = {}) const + bool_constant = {}, + bool_constant = {}) const { using Traits = load_store_traits; @@ -384,7 +385,13 @@ struct tile_window_with_static_distribution auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); @@ -395,7 +402,8 @@ struct tile_window_with_static_distribution get_bottom_tensor_view().template get_vectorized_elements_raw( dst_vec_tbuf.template at(), bottom_tensor_thread_coord, - bool_constant{}); + bool_constant{}, + pre_nop_); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -410,12 +418,17 @@ struct tile_window_with_static_distribution } }); }); +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE + asm volatile("; this inline asm is workaround to prevent compiler from using too much " + "scratch memory" ::); +#endif } // TODO: currently async load only implemented in inline asm - template - CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, - bool_constant = {}) const + template + CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + bool_constant = {}, + bool_constant = {}) const { using LdsTileWindow = remove_cvref_t; // using LdsTensorView = typename LdsTileWindow::BottomTensorView; @@ -460,11 +473,17 @@ struct tile_window_with_static_distribution auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements( - smem, bottom_tensor_thread_coord); + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, bottom_tensor_thread_coord, pre_nop_); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -605,6 +624,66 @@ struct tile_window_with_static_distribution }); } + template + CK_TILE_DEVICE void update(const static_distributed_tensor& dstr_tensor, + bool_constant = {}) const + { + using Traits = load_store_traits; + + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from distributed tensor + vector_t vec_value; + + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_array( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + vec_value.template get_as()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // write into bottom tensor + get_bottom_tensor_view().template update_vectorized_elements( + bottom_tensor_thread_coord, vec_value, bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + // move thread's botom tensor coordiante // [x0', x1', ... ] ==> [offset] // also move window-origin @@ -619,6 +698,67 @@ struct tile_window_with_static_distribution }); } + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + { + window_origin_ = new_window_origin; + +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_dstr_), array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } + // this is the bottom tensor view // [x0', x1', ...] ==> [offset] BottomTensorView bottom_tensor_view_; diff --git a/include/ck_tile/core/tensor/update_tile.hpp b/include/ck_tile/core/tensor/update_tile.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fbce7c40839bcdf2df329b6fe0450ad62be6b9b3 --- /dev/null +++ b/include/ck_tile/core/tensor/update_tile.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +CK_TILE_DEVICE void +update_tile(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr); + + tile_window.update(dstr_tensor); +} + +template +CK_TILE_DEVICE void +update_tile(tile_window_with_static_distribution& tile_window, + const static_distributed_tensor& dstr_tensor) +{ + tile_window.update(dstr_tensor); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/philox_rand.hpp b/include/ck_tile/core/utility/philox_rand.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c49f44ae48f65564fe94499ff0e30df693beaada --- /dev/null +++ b/include/ck_tile/core/utility/philox_rand.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh +class philox +{ + public: + CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_) + : seed(reinterpret_cast(seed_)) + { + + ull2* tmp = reinterpret_cast(&counter); + tmp->x = offset_; + } + + CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const + { + + uint4 counter_ = counter; + ull2* tmp = reinterpret_cast(&counter_); + tmp->y = subsequence; + + uint2 key_ = seed; +// 7-round philox +#pragma unroll + for(int i = 0; i < 6; i++) + { + counter_ = philox_single_round(counter_, key_); + key_.x += kPhilox10A; + key_.y += kPhilox10B; + } + uint4 output = philox_single_round(counter_, key_); + return output; + } + + CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out, + const unsigned long long subsequence) const + { + uint4 tmp_ph; + tmp_ph = get_philox_4x32(subsequence); + + uint32_t* out_tmp = reinterpret_cast(&out[0]); + + out_tmp[0] = tmp_ph.x; + out_tmp[1] = tmp_ph.y; + out_tmp[2] = tmp_ph.z; + out_tmp[3] = tmp_ph.w; + } + + private: + struct ull2 + { + uint64_t x; + uint64_t y; + }; + uint4 counter; + const uint2 seed; + + CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const + { + uint2* res; + unsigned long long tmp; + tmp = static_cast(a) * b; + res = reinterpret_cast(&tmp); + return *res; + } + + CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const + { + + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; + } + + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + static const unsigned long kPhiloxSA = 0xD2511F53; + static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 0c4a7782269ba46ed0292c69e9ffa669491e1c7f..0e69a925d510b72d3a16ad28460337eee5be2e5d 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -11,12 +11,15 @@ #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/ranges.hpp" +#include "ck_tile/host/reference/reference_batched_dropout.hpp" #include "ck_tile/host/reference/reference_batched_elementwise.hpp" #include "ck_tile/host/reference/reference_batched_gemm.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" +#include "ck_tile/host/reference/reference_layernorm2d.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" +#include "ck_tile/host/timer.hpp" diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 1ef9b24138f5ac0c00eb3548c6355482bcb56bc3..529bfdff25fd3201c48aeaa99766d032f771a69d 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -56,8 +56,9 @@ check_err(const Range& out, } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; @@ -114,8 +115,9 @@ check_err(const Range& out, } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; @@ -173,8 +175,9 @@ check_err(const Range& out, } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; @@ -285,8 +288,9 @@ std::enable_if_t<(std::is_same_v, ranges::range_val } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; @@ -357,8 +361,9 @@ std::enable_if_t<(std::is_same_v, ranges::range_val } const auto is_infinity_error = [=](auto o, auto r) { - const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); - const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r); + const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r); + const bool both_infinite_and_same = + std::isinf(o) && std::isinf(r) && (bit_cast(o) == bit_cast(r)); return either_not_finite && !(allow_infinity_ref && both_infinite_and_same); }; diff --git a/include/ck_tile/host/device_memory.hpp b/include/ck_tile/host/device_memory.hpp index 91463a06a985457b610b19587f5bfb47510992fd..7c8549f74fe920f7f581de6b21645e96d4835296 100644 --- a/include/ck_tile/host/device_memory.hpp +++ b/include/ck_tile/host/device_memory.hpp @@ -27,7 +27,14 @@ struct DeviceMem DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {} DeviceMem(std::size_t mem_size) : mMemSize(mem_size) { - HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + if(mMemSize != 0) + { + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + else + { + mpDeviceBuf = nullptr; + } } void Realloc(std::size_t mem_size) { @@ -36,7 +43,14 @@ struct DeviceMem HIP_CHECK_ERROR(hipFree(mpDeviceBuf)); } mMemSize = mem_size; - HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + if(mMemSize != 0) + { + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + else + { + mpDeviceBuf = nullptr; + } } void* GetDeviceBuffer() const { return mpDeviceBuf; } std::size_t GetBufferSize() const { return mMemSize; } @@ -47,15 +61,18 @@ struct DeviceMem HIP_CHECK_ERROR( hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); } - else - { - throw std::runtime_error("ToDevice with an empty pointer"); - } + // else + // { + // throw std::runtime_error("ToDevice with an empty pointer"); + // } } void ToDevice(const void* p, const std::size_t cpySize) const { - HIP_CHECK_ERROR( - hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); + if(mpDeviceBuf) + { + HIP_CHECK_ERROR( + hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); + } } void FromDevice(void* p) const { @@ -63,14 +80,17 @@ struct DeviceMem { HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); } - else - { - throw std::runtime_error("FromDevice with an empty pointer"); - } + // else + // { + // throw std::runtime_error("FromDevice with an empty pointer"); + // } } void FromDevice(void* p, const std::size_t cpySize) const { - HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + if(mpDeviceBuf) + { + HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + } } void SetZero() const { @@ -82,13 +102,16 @@ struct DeviceMem template void SetValue(T x) const { - if(mMemSize % sizeof(T) != 0) + if(mpDeviceBuf) { - throw std::runtime_error("wrong! not entire DeviceMem will be set"); - } + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! not entire DeviceMem will be set"); + } - // TODO: call a gpu kernel to set the value (?) - set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + // TODO: call a gpu kernel to set the value (?) + set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + } } ~DeviceMem() { diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index cd0dc382591fb339a9b8beec61f68a2bd60276c6..43405ee69b3321617dc890fc67047858eed5ae44 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -156,7 +156,7 @@ struct HostTensorDescriptor } const std::vector& get_lengths() const { return mLens; } - const std::vector& GetStrides() const { return mStrides; } + const std::vector& get_strides() const { return mStrides; } template std::size_t GetOffsetFromMultiIndex(Is... is) const @@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old for(std::size_t i = 0; i < a.get_num_of_dimension(); i++) { new_lengths[i] = a.get_lengths()[new2old[i]]; - new_strides[i] = a.GetStrides()[new2old[i]]; + new_strides[i] = a.get_strides()[new2old[i]]; } return HostTensorDescriptor(new_lengths, new_strides); @@ -327,7 +327,7 @@ struct HostTensor decltype(auto) get_lengths() const { return mDesc.get_lengths(); } - decltype(auto) GetStrides() const { return mDesc.GetStrides(); } + decltype(auto) get_strides() const { return mDesc.get_strides(); } std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); } @@ -481,6 +481,34 @@ struct HostTensor return mData[mDesc.GetOffsetFromMultiIndex(idx)]; } + HostTensor transpose(std::vector axes = {}) const + { + if(axes.empty()) + { + axes.resize(this->get_num_of_dimension()); + std::iota(axes.rbegin(), axes.rend(), 0); + } + if(axes.size() != mDesc.get_num_of_dimension()) + { + throw std::runtime_error( + "HostTensor::transpose(): size of axes must match tensor dimension"); + } + std::vector tlengths, tstrides; + for(const auto& axis : axes) + { + tlengths.push_back(get_lengths()[axis]); + tstrides.push_back(get_strides()[axis]); + } + HostTensor ret(*this); + ret.mDesc = HostTensorDescriptor(tlengths, tstrides); + return ret; + } + + HostTensor transpose(std::vector axes = {}) + { + return const_cast const*>(this)->transpose(axes); + } + typename Data::iterator begin() { return mData.begin(); } typename Data::iterator end() { return mData.end(); } diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index 7053888abd35980a4aa320a1a3988521e98ac4de..e9c5a0c25491ec39790daf03ec274a2c37a9a6f2 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/timer.hpp" #include #include @@ -14,153 +15,92 @@ template -CK_TILE_HOST float launch_and_time_kernel(const stream_config& s, - F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - Args... args) +// +// return a anonymous functor(lambda) to be called later +// the KernelImpl should be a class without non-static data member, or let's say +// can be instantiate with "KernelImpl{}" +// +// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl +// +template +CK_TILE_HOST auto +make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { -#if CK_TILE_TIME_KERNEL - if(s.time_kernel_) - { - // warm up - for(int i = 0; i < s.cold_niters_; ++i) - { - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } - - const int nrepeat = s.nrepeat_; - hipEvent_t start, stop; - - HIP_CHECK_ERROR(hipEventCreate(&start)); - HIP_CHECK_ERROR(hipEventCreate(&stop)); - - HIP_CHECK_ERROR(hipDeviceSynchronize()); - HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_)); - - for(int i = 0; i < nrepeat; ++i) - { - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } - - HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_)); - HIP_CHECK_ERROR(hipEventSynchronize(stop)); - - float total_time = 0; - - HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); + const auto kernel = kentry; - return total_time / nrepeat; - } - else - { + return [=](const stream_config& s) { kernel<<>>(args...); - hip_check_error(hipGetLastError()); - return 0; - } -#else - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - return 0; -#endif + }; } -template -CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s, - PreProcessFunc preprocess, - F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - Args... args) +// clang-format off +/* + * launch_kernel() + * + * this is the function to launch arbitrary number of kernels with optional timer(selected by stream_config) + * the callables should have signature as "operator()(const stream_config& s){ ... }" to call + * + * the simplest way is pass in a lambda function, with "[=](const stream_config& s){ call_your_kernel_here() }" + * as signature, for the callable (pay attention to the capture list) + * + * e.g. + * ck_tile::launch_kernel(s, + * [=](const stream_config& s){ hipMemset(ptr, 0, size) }, + * [=](const stream_config& s){ some_kernel<<>>(arg); } + * ); + * + * if you use ck_tile kernel, or similiar to this style (structure with "static __device__ operator()(...){}") + * you can pass your kernel to ck_tile::make_kernel(), which will create a anonymous functor for you, + * then pass it to ck_tile::launch_kernel() + * + * e.g. + * ck_tile::launch_kernel(s, + * ck_tile::make_kernel(kernel_0{}, grids0, blocks0, 0, kargs0), + * ck_tile::make_kernel(kernel_1{}, grids1, blocks1, 0, kargs1), + * ...); + **/ +// clang-format on +template +CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables) { -#if CK_TILE_TIME_KERNEL - if(s.time_kernel_) - { -#if CK_TILE_DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up 1 time\n"); -#endif - // warm up - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - - const int nrepeat = 10; -#if CK_TILE_DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif - hipEvent_t start, stop; - - HIP_CHECK_ERROR(hipEventCreate(&start)); - HIP_CHECK_ERROR(hipEventCreate(&stop)); - - HIP_CHECK_ERROR(hipDeviceSynchronize()); - HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_)); + // clang-format off + if(!s.time_kernel_) { + (callables(s),...); hip_check_error(hipGetLastError()); + return 0; + } + if(s.is_gpu_timer_) { + gpu_timer timer {}; - for(int i = 0; i < nrepeat; ++i) - { - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } + // warmup + for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); - HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_)); - HIP_CHECK_ERROR(hipEventSynchronize(stop)); + timer.start(s.stream_id_); + for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + timer.stop(s.stream_id_); - float total_time = 0; + return timer.duration() / s.nrepeat_; + } + else { + cpu_timer timer {}; - HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); + // warmup + for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); - return total_time / nrepeat; - } - else - { - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); + timer.start(s.stream_id_); + for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + timer.stop(s.stream_id_); - return 0; + return timer.duration() / s.nrepeat_; } -#else - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - - return 0; -#endif + // clang-format on } -template -CK_TILE_HOST float launch_kernel(const stream_config& s, - KernelImpl kernel_impl, - dim3 grid_dim, - dim3 block_dim, - std::size_t dynamic_smem_byte, - Args... args) -{ - const auto kernel = kentry; - - return launch_and_time_kernel( - s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...); -} } // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_batched_dropout.hpp b/include/ck_tile/host/reference/reference_batched_dropout.hpp new file mode 100644 index 0000000000000000000000000000000000000000..242101bf4dd3dc26a11ca7565e6ba5fee5d080c3 --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_dropout.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +CK_TILE_HOST void reference_batched_dropout(HostTensor& in_out_b_m_n, + const HostTensor& randval_b_m_n, + const uint8_t& p_undrop_in_uint8_t, + const float scale) +{ + const int N = in_out_b_m_n.mDesc.get_lengths()[2]; + auto f = [&](auto batch, auto m) { + for(int n = 0; n < N; ++n) + { + float tmp = ck_tile::type_convert(in_out_b_m_n(batch, m, n)) * scale; + in_out_b_m_n(batch, m, n) = randval_b_m_n(batch, m, n) <= p_undrop_in_uint8_t + ? ck_tile::type_convert(tmp) + : DataType(0); + } + }; + + make_ParallelTensorFunctor( + f, randval_b_m_n.mDesc.get_lengths()[0], randval_b_m_n.mDesc.get_lengths()[1])( + std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_layernorm2d.hpp b/include/ck_tile/host/reference/reference_layernorm2d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..837f52c399a5d81dcd103e9e2c92c4ed953782a4 --- /dev/null +++ b/include/ck_tile/host/reference/reference_layernorm2d.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +void reference_layernorm2d_fwd(const HostTensor& x_m_n, + const HostTensor& gamma_n, + const HostTensor& beta_n, + HostTensor& y_m_n, + HostTensor& mean_m, + HostTensor& invStd_m, + ComputeDataType epsilon) +{ + auto layernorm2d_fwd_func = [&](auto m) { + const int N = x_m_n.mDesc.get_lengths()[1]; + + int count = 0; + ComputeDataType mean = 0; + ComputeDataType variance = 0; + ComputeDataType divisor = 0; + + for(int n = 0; n < N; ++n) + { + ++count; + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + ComputeDataType delta = x - mean; + mean += delta / count; + ComputeDataType delta2 = x - mean; + variance += delta * delta2; + } + + // actual variance + variance = variance / count; + divisor = ck_tile::type_convert(1) / ck_tile::sqrt(variance + epsilon); + + if constexpr(!std::is_same_v) + mean_m(m) = ck_tile::type_convert(mean); + + if constexpr(!std::is_same_v) + invStd_m(m) = ck_tile::type_convert(divisor); + + for(int n = 0; n < N; ++n) + { + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); + ComputeDataType beta = ck_tile::type_convert(beta_n(n)); + auto y = (x - mean) * divisor; + y = y * gamma + beta; + + y_m_n(m, n) = ck_tile::type_convert(y); + } + }; + + make_ParallelTensorFunctor(layernorm2d_fwd_func, + mean_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/stream_config.hpp b/include/ck_tile/host/stream_config.hpp index d29c6f0fa14b95da1b41a472174d8c0284c3516a..47cf0fd5e414f018380c3c890992658bc04293b8 100644 --- a/include/ck_tile/host/stream_config.hpp +++ b/include/ck_tile/host/stream_config.hpp @@ -6,6 +6,22 @@ #include namespace ck_tile { +/* + * construct this structure with behavior as: + * + * // create stream config with default stream(NULL), and not timing the kernel + * stream_config s = stream_config{}; + * + * // create stream config with _some_stream_id_, and not timing the kernel + * stream_config s = stream_config{_some_stream_id_}; + * + * // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default + * stream_config s = stream_config{_some_stream_id_, true}; + * + * // create stream config with _some_stream_id_, and benchmark using cpu timer + * stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false}; + **/ + struct stream_config { hipStream_t stream_id_ = nullptr; @@ -13,5 +29,6 @@ struct stream_config int log_level_ = 0; int cold_niters_ = 3; int nrepeat_ = 10; + bool is_gpu_timer_ = true; // keep compatible }; } // namespace ck_tile diff --git a/include/ck_tile/host/timer.hpp b/include/ck_tile/host/timer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e5519643bfdc050bf07a3cfe65bda773c0e2a371 --- /dev/null +++ b/include/ck_tile/host/timer.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include +#include +#include + +namespace ck_tile { + +struct gpu_timer +{ + CK_TILE_HOST gpu_timer() + { + HIP_CHECK_ERROR(hipEventCreate(&start_evt)); + HIP_CHECK_ERROR(hipEventCreate(&stop_evt)); + } + + CK_TILE_HOST ~gpu_timer() noexcept(false) + { + HIP_CHECK_ERROR(hipEventDestroy(start_evt)); + HIP_CHECK_ERROR(hipEventDestroy(stop_evt)); + } + + CK_TILE_HOST void start(const hipStream_t& s) + { + HIP_CHECK_ERROR(hipStreamSynchronize(s)); + HIP_CHECK_ERROR(hipEventRecord(start_evt, s)); + } + + CK_TILE_HOST void stop(const hipStream_t& s) + { + HIP_CHECK_ERROR(hipEventRecord(stop_evt, s)); + HIP_CHECK_ERROR(hipEventSynchronize(stop_evt)); + } + // return in ms + CK_TILE_HOST float duration() const + { + float ms = 0; + HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt)); + return ms; + } + + private: + hipEvent_t start_evt, stop_evt; +}; + +struct cpu_timer +{ + // torch.utils.benchmark.Timer(), there is a sync inside each timer callback + CK_TILE_HOST void start(const hipStream_t& s) + { + HIP_CHECK_ERROR(hipStreamSynchronize(s)); + start_tick = std::chrono::high_resolution_clock::now(); + } + // torch.utils.benchmark.Timer(), there is a sync inside each timer callback + CK_TILE_HOST void stop(const hipStream_t& s) + { + HIP_CHECK_ERROR(hipStreamSynchronize(s)); + stop_tick = std::chrono::high_resolution_clock::now(); + } + // return in ms + CK_TILE_HOST float duration() const + { + double sec = + std::chrono::duration_cast>(stop_tick - start_tick) + .count(); + return static_cast(sec * 1e3); + } + + private: + std::chrono::time_point start_tick; + std::chrono::time_point stop_tick; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index c567e63ddfce88da645627bd9febaa2303fe720f..057d2b11ff79fd0200a0199ffbd14fe0a09ae169 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -3,9 +3,35 @@ #pragma once +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/fmha/block/block_position_encoding.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e5be21e0489e661c592de18093121d2c46c1b890 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockAttentionBiasEnum +{ + NO_BIAS = 0, + ELEMENTWISE_BIAS = 1, // attention bias, each elements add to the result of Q*K(after scale) + ALIBI = 2, // bias computed with position encoding, applied after scale +}; + +template +struct BlockAttentionBiasEnumToStr; + +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = ""; +}; +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = "bias"; +}; +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = "alibi"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7ebb306cce568031dbf2cc4af05a2b34e2c21694 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -0,0 +1,364 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +struct NullBlockDropout +{ + template + __host__ __device__ static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + (void)randval_dram_block_window_tmp; + (void)seqlen_qk_start; + + return make_null_tile_window(make_tuple(number<0>{}, number<0>{})); + } +}; + +struct BlockDropout +{ + CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, + index_t i_head, + index_t nheads, + unsigned long long seed, + unsigned long long offset, + float rp_undrop_, + uint8_t p_undrop_in_uint8_t_, + bool is_store_randval_) + : ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()), + rp_undrop(rp_undrop_), + p_undrop_in_uint8_t(p_undrop_in_uint8_t_), + is_store_randval(is_store_randval_) + { + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + index_t seqlen_qk_start) + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + const auto block_origin = randval_dram_block_window_tmp.get_window_origin(); + auto randval_dram_window = [&]() { + if constexpr(IsFwd) + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N + } + else + { + return make_tile_window( + randval_dram_block_window_tmp.get_bottom_tensor_view(), + ck_tile::make_tuple(number{}, number{}), + {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N + } + }(); + + return randval_dram_window; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = WG::kN; + constexpr index_t kN1 = 8; + constexpr index_t kN0 = kNPerStep / kN1; + + constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor( + ck_tile::make_tuple(number{}, number{}, number{}), + ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto randval_lds_block_desc = transform_tensor_descriptor( + randval_lds_block_desc_0, + ck_tile::make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(ck_tile::make_tuple(number{}, number{}))), + ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}), + ck_tile::make_tuple(sequence<0>{}, sequence<1>{})); + + return randval_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + // Use Bwd WarpGemm to ensure that Fwd's random values ​​are consistent with Bwd. + constexpr auto randval_block_inner_part_dstr_encoding = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; + } + else + { + return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{}; + } + }(); + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + randval_block_inner_part_dstr_encoding); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution() + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = 1; + constexpr index_t NIterPerWarp = 1; + + constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto randval_block_part_dstr_encode = + detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding, + typename WG::CWarpDstrEncoding{}); + + return make_static_tile_distribution(randval_block_part_dstr_encode); + } + + template + CK_TILE_HOST_DEVICE void Run(void* randval_ptr, + const index_t start_n0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // randval tile in LDS + auto randval_lds = make_tensor_view( + reinterpret_cast(randval_ptr), MakeRandValLdsBlockDescriptor()); + + auto randval_lds_window = make_tile_window( + randval_lds, MakeRandValLdsBlockDescriptor().get_lengths(), {0, 0}); + + // register distribute + auto randval_dist_generated = + make_static_distributed_tensor(MakeRandValTileDistribution()); + static_assert(randval_dist_generated.kThreadElementSpaceSize == 16); + + auto randval_lds_read_window = + make_tile_window(randval_lds_window.get_bottom_tensor_view(), + randval_lds_window.get_window_lengths(), + randval_lds_window.get_window_origin(), + MakeRandValLdsShuffleTileDistribution()); + + const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{}); + if(is_store_randval) + { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); + int block_col_start = (start_n0_idx / WG::kN) + i_n0; + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, + reinterpret_cast(rowcol)); + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // save to LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + // read from LDS to register + auto randval = load_tile(randval_lds_read_window); + // save to Global + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {0, kNPerStep}); + }); + move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock}); + }); + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock}); + }; + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id(); + int block_col_start = (start_n0_idx / WG::kN) + i_n0; + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); + + constexpr auto randval_dist_generated_spans = + decltype(randval_dist_generated)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1); + randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++]; + }); + }); + // save to LDS + store_tile(randval_lds_window, randval_dist_generated); + block_sync_lds(); + // read from LDS to register + auto randval = load_tile(randval_lds_read_window); + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto p_idx0 = tile_distributed_index{}; + constexpr auto p_idx1 = + tile_distributed_index{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] * rp_undrop + : PComputeDataType(0); + }); + }); + }); + }); + } + + template + CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, + PComputeWindow& p_compute, + RandValDramWindow& randval_dram_window) const + { + constexpr auto config = + BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + using BlockGemmShape = remove_cvref_t; + constexpr index_t kMPerBlock = BlockGemmShape::kM; + constexpr index_t kNPerBlock = BlockGemmShape::kN; + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = NWarp * WG::kN; + + // register distribute + auto randval = + make_static_distributed_tensor(MakeRandValTileDistribution()); + static_assert(randval.kThreadElementSpaceSize == 16); + + const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{}); + static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) { + static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) { + int block_row_start = (start_m0_idx / WG::kM) + i_m0; + int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id(); + uint2 rowcol = make_uint2(block_row_start, block_col_start); + + // generate random number + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, reinterpret_cast(rowcol)); + + constexpr auto randval_spans = decltype(randval)::get_distributed_spans(); + int i_random_idx = 0; + sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) { + constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1); + randval(r_idx) = random_uint8_t[i_random_idx++]; + constexpr auto p_idx0 = + tile_distributed_index{}; + constexpr auto p_idx1 = tile_distributed_index{}; + constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1); + p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t + ? p_compute[p_idx] + : -p_compute[p_idx]; + }); + }); + // save to Global + if(is_store_randval) + { + const auto randval_store = cast_tile(randval); + store_tile(randval_dram_window, randval_store); + move_tile_window(randval_dram_window, {kMPerStep, 0}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep}); + } + }); + if(is_store_randval) + { + move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock}); + } + } + + ck_tile::philox ph; + const float rp_undrop; + const uint8_t p_undrop_in_uint8_t; + const bool is_store_randval; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 7fb1c19b5f31ce68c29f385906922b4abde39462..c022edf723d7e071eae6f96b83292c8215b53a05 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -141,6 +141,36 @@ struct GenericAttentionMask } } + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const { @@ -160,14 +190,14 @@ struct GenericAttentionMask } else { - return i_x >= x_end; + return i_x >= x_end || i_y >= y_total; } } } // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel - // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() // can be used as a fast-path to decide if do per-pixel check or not template CK_TILE_HOST_DEVICE constexpr auto @@ -269,6 +299,53 @@ struct SimplifiedGenericAttentionMask } } + template + CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, + number height, + number width, + index_t num_splits, + index_t i_split) const + { + auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); + + const index_t x_per_split = ck_tile::max(1, x_total / num_splits); + const index_t split_start = x_per_split * i_split; + const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split); + + return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), + ck_tile::min(origin_end, split_end)); + } + + // to get the loop length along Y axis, return index:[start, end), end-start=length + // use this if need loop over Y axis tile by tile (like q-seqlen loopover) + // TODO: y_end still could be negative, so end-start could be negative(need check) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongY(index_t i_x, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, y_total); + } + else + { + // get the tile start/end range assum we loop over along Y tile by tile + index_t y_start = [&]() { + index_t tmp = max(-x + i_x + 1, 0); + return (tmp / YTile) * YTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t y_end = [&]() { + index_t tmp = min(i_x + XTile - 1 + y, y_total); + return ((tmp + YTile - 1) / YTile) * YTile; + }(); + + return ck_tile::make_tuple(y_start, y_end); + } + } + // per-pixel check if out-of-bound, if true, need mask a value(like -INF) CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const { @@ -283,13 +360,13 @@ struct SimplifiedGenericAttentionMask index_t x_start = -y + i_y + 1; // this could be negative, but it's fine index_t x_end = min(i_y + x, x_total); // need min in case x is padded - return i_x < x_start || i_x >= x_end; + return i_x < x_start || i_x >= x_end || i_y >= y_total; } } // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel - // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX() + // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() // can be used as a fast-path to decide if do per-pixel check or not template CK_TILE_HOST_DEVICE constexpr auto @@ -312,7 +389,7 @@ struct SimplifiedGenericAttentionMask // index_t x_end = min(i_y + x, x_total); bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad - bool bottom_left_edge = i_y_end > (i_x + y); + bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now return top_right_edge || bottom_left_edge; @@ -361,6 +438,6 @@ make_generic_attention_mask_from_lr_window(index_t left_size, { auto r = make_generic_attention_mask_coordinates_from_lr_window( left_size, right_size, y_total, x_total, is_top_left); - return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total}; + return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total}; } } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c2fdaf3a1afedd4556ef37438a4c0d3b3bc7c424 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp @@ -0,0 +1,189 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include +#include + +namespace ck_tile { + +enum struct PositionEncodingEnum +{ + NO = 0, + ALIBI = 1, +}; + +/* +VERTICAL: + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + +TOP_LEFT(but negative): + [0] 1 2 3 4 5 + 1 [0] 1 2 3 4 + 2 1 [0] 1 2 3 + 3 2 1 [0] 1 2 + +FROM_BOTTOM_RIGHT(but negative): + 2 1 [0] 1 2 3 + 3 2 1 [0] 1 2 + 4 3 2 1 [0] 1 + 5 4 3 2 1 [0] +*/ + +enum struct AlibiMode +{ + VERTICAL = 0, + FROM_TOP_LEFT = 1, // keep sync with mask enum + FROM_BOTTOM_RIGHT = 2, +}; + +template +struct Alibi +{ + // RowMajor here means if pixel within the same thread are along the row, or col + // this may impact the performance of update(), while the result are the same. + // e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false + CK_TILE_HOST_DEVICE Alibi(DataType slope_, + index_t y_total_, + index_t x_total_, + AlibiMode mode_ = AlibiMode::VERTICAL) + { + slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_; + + shift_left_up = [&]() { + if(RowMajor) + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; + } + else + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0; + } + }(); + shift_right_down = [&]() { + if(RowMajor) + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0; + } + else + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; + } + }(); + mode = mode_; + } + + CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx) + { + if constexpr(RowMajor) + { + // at least 3 instructions per row + index_t current_zero_point = + mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down; + + // for every threads, most of the pixels are along the row, below operation should be + // the main hot spot. + auto position = type_convert(sad(bit_cast(current_zero_point), + bit_cast(col_idx + shift_left_up), + 0)); + pixel += slope * position; + } + else + { + // at least 3 instructions per col; + index_t current_zero_point = mode == AlibiMode::VERTICAL + ? row_idx + col_idx + shift_right_down + : col_idx + shift_right_down; + + // for every threads, most of the pixels are along the col, below operation should be + // the main hot spot. + auto position = type_convert(sad(bit_cast(current_zero_point), + bit_cast(row_idx + shift_left_up), + 0)); + pixel += slope * position; + } + } + + DataType slope; // float? + index_t shift_left_up; // always possitive + index_t shift_right_down; // always possitive + AlibiMode mode; +}; + +template +struct EmptyPositionEncoding +{ + CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/) + { + } +}; + +// +// can convert from the FA style left/right to our generic coordinate +// if left_size < 0 && right_size = 0, it is normal causal mask +// local is left_size >=0 or right_size >=0 +template +CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, + index_t window_left_size, + index_t window_right_size, + index_t y_total, + index_t x_total, + GenericAttentionMaskEnum mask_enum) +{ + // assume mask_enum will never be NO_MASK, since if we do not have mask, it's + // totally OK to use constexpr + bool is_causal = window_left_size < 0 && window_right_size == 0; + AlibiMode alibi_mode = + is_causal ? AlibiMode::VERTICAL + : static_cast(mask_enum) /*either top-left or bottom-right*/; + return Alibi{slope, y_total, x_total, alibi_mode}; +} + +// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742 +// Do we need a device version? +template +CK_TILE_HOST std::vector get_alibi_slopes(ck_tile::index_t nheads) +{ + auto get_slopes_power_of_2 = [](ck_tile::index_t n) { + float start = std::powf( + static_cast(2), + -std::powf(static_cast(2), -static_cast((integer_log2_floor(n) - 3)))); + + std::vector rtn; + for(auto i = 0; i < n; i++) + { + rtn.push_back(static_cast(start * std::powf(start, i))); + } + return rtn; + }; + if(is_power_of_two_integer(nheads)) + { + // power of 2 calculation + return get_slopes_power_of_2(nheads); + } + else + { + ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads); + auto v0 = get_slopes_power_of_2(closest_power_of_2); + auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2); + auto v1_sliced = [&](auto vec, ck_tile::index_t rem) { + std::vector sliced; + for(ck_tile::index_t i = 0; i < static_cast(vec.size()); i++) + { + if(i % 2 == 0) + sliced.push_back(vec[i]); + } + std::vector sliced_2(sliced.begin(), sliced.begin() + rem); + return sliced_2; + }(v1, nheads - closest_power_of_2); + v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end()); + return v0; + } +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e713cefbdaee1c329fbb17c2531e00befb6b4f44 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -0,0 +1,1421 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// dV[seqlen_k, hdim_v] = P^T[seqlen_k, seqlen_q] @ dO^T[hdim_v, seqlen_q] +// dP[seqlen_q, seqlen_k] = dO[seqlen_q, hdim_v] @ V[seqlen_k, hdim_v] +// D[seqlen_q] = rowsum(dO[seqlen_q, hdim_v] * O[seqlen_q, hdim_v]) +// dS''[seqlen_q, seqlen_k] = P[seqlen_q, seqlen_k] * (dP[seqlen_q, seqlen_k] - D[seqlen_q]) +// dBias[seqlen_q, seqlen_k] = dS'[seqlen_q, seqlen_k] = dS''[seqlen_q, seqlen_k] +// dK[seqlen_k, hdim_q] = dS'^T[seqlen_k, seqlen_q] @ Q^T[hdim_q, seqlen_q] * Scale[1] +// dQ[seqlen_q, hdim_q] = dS'[seqlen_q, seqlen_k] @ K^T[hdim_q, seqlen_k] * Scale[1] + +namespace ck_tile { + +template +struct FmhaBwdDQDKDVKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaPipeline = ck_tile::remove_cvref_t; + using KGradEpiloguePipeline = ck_tile::remove_cvref_t; + using VGradEpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using GemmDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using AccDataType = ck_tile::remove_cvref_t; + using DDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using OGradDataType = ck_tile::remove_cvref_t; + using QGradDataType = ck_tile::remove_cvref_t; + using KGradDataType = ck_tile::remove_cvref_t; + using VGradDataType = ck_tile::remove_cvref_t; + using BiasGradDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using gbr = typename bfs::Gemm0BlockWarps; + using gwt = typename bfs::Gemm0WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" + + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaBwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaBwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* lse_ptr; + const void* do_ptr; + const void* d_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t num_head_q; + ck_tile::index_t nhead_ratio_qk; + float raw_scale; +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale; +#endif + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + + ck_tile::index_t batch_stride_lsed; + }; + + struct FmhaBwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct FmhaBwdBatchModeBiasKargs : FmhaBwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct FmhaBwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct FmhaBwdCommonBiasGradKargs + { + void* dbias_ptr = nullptr; + ck_tile::index_t stride_dbias = 0; + ck_tile::index_t nhead_stride_dbias = 0; + }; + + struct FmhaBwdBatchModeBiasGradKargs : FmhaBwdCommonBiasGradKargs + { + ck_tile::index_t batch_stride_dbias = 0; + }; + + struct FmhaBwdMaskKargs + { + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaBwdCommonDropoutKargs + { + void init_dropout(const float p_drop, + const std::tuple& drop_seed_offset, + const float raw_scale) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + scale_rp_undrop = rp_undrop * raw_scale; + + drop_seed = std::get<0>(drop_seed_offset); + drop_offset = std::get<1>(drop_seed_offset); + } + float rp_undrop = 1; + float scale_rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + uint64_t drop_seed = 1; + uint64_t drop_offset = 0; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct FmhaBwdBatchModeKargs + : FmhaBwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + }; + + struct FmhaBwdGroupModeKargs + : FmhaBwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + const void* lse_ptr, + const void* do_ptr, + const void* d_ptr, + void* rand_val_ptr, + void* dq_ptr, + void* dk_ptr, + void* dv_ptr, + void* dbias_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_do, + ck_tile::index_t stride_dk, + ck_tile::index_t stride_dv, + ck_tile::index_t stride_dbias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dbias, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_do, + ck_tile::index_t batch_stride_lsed, + ck_tile::index_t batch_stride_dk, + ck_tile::index_t batch_stride_dv, + ck_tile::index_t batch_stride_dbias, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_ptr, + do_ptr, + d_ptr, + dq_ptr, + dk_ptr, + dv_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck_tile::log2e_v<>), +#endif + stride_q, + stride_k, + stride_v, + stride_do, + stride_dk, + stride_dv, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_do, + nhead_stride_lsed, + batch_stride_lsed}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_do, + batch_stride_dk, + batch_stride_dv}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + + if constexpr(kHasBiasGrad) + { + kargs.dbias_ptr = dbias_ptr; + kargs.stride_dbias = stride_dbias; + kargs.nhead_stride_dbias = nhead_stride_dbias; + kargs.batch_stride_dbias = batch_stride_dbias; + } + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset, scale); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + const void* lse_ptr, + const void* do_ptr, + const void* d_ptr, + void* rand_val_ptr, + void* dq_ptr, + void* dk_ptr, + void* dv_ptr, + void* dbias_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_do, + ck_tile::index_t stride_dk, + ck_tile::index_t stride_dv, + ck_tile::index_t stride_dbias, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_lsed, + ck_tile::index_t nhead_stride_dbias, + ck_tile::index_t batch_stride_lsed, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_ptr, + do_ptr, + d_ptr, + dq_ptr, + dk_ptr, + dv_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale * ck_tile::log2e_v<>), +#endif + stride_q, + stride_k, + stride_v, + stride_do, + stride_dk, + stride_dv, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_do, + nhead_stride_lsed, + batch_stride_lsed}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for dbias + {}, // placeholder for mask + {}, // placeholder for dropout + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasBiasGrad) + { + kargs.dbias_ptr = dbias_ptr; + kargs.stride_dbias = stride_dbias; + kargs.nhead_stride_dbias = nhead_stride_dbias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset, scale); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), + KGradEpiloguePipeline::GetSmemSize(), + VGradEpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k); + + const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_do = 0; + long_index_t batch_offset_lsed = 0; + long_index_t batch_offset_dk = 0; + long_index_t batch_offset_dv = 0; + long_index_t batch_offset_dbias = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; + batch_offset_do = query_start * kargs.stride_do; + batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; + batch_offset_dk = key_start * kargs.stride_dk; + batch_offset_dv = key_start * kargs.stride_dv; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + if constexpr(kHasBiasGrad) + { + batch_offset_dbias = query_start * kargs.stride_dbias; + } + else + { + batch_offset_dbias = key_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_k <= i_n0) + { + return; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_do = static_cast(i_batch) * kargs.batch_stride_do; + batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; + batch_offset_dk = static_cast(i_batch) * kargs.batch_stride_dk; + batch_offset_dv = static_cast(i_batch) * kargs.batch_stride_dv; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kHasBiasGrad) + { + batch_offset_dbias = static_cast(i_batch) * kargs.batch_stride_dbias; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + const LSEDataType* lse_ptr = reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lsed + + batch_offset_lsed; + const DDataType* d_ptr = reinterpret_cast(kargs.d_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lsed + + batch_offset_lsed; + const OGradDataType* do_ptr = reinterpret_cast(kargs.do_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_do + + batch_offset_do; + QGradDataType* dq_ptr = reinterpret_cast(kargs.dq_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + KGradDataType* dk_ptr = reinterpret_cast(kargs.dk_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_k + + batch_offset_dk; + VGradDataType* dv_ptr = reinterpret_cast(kargs.dv_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_v + + batch_offset_dv; + + // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + const auto q_dram = [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto qt_dram_naive = + transform_tensor_view(q_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_q), + make_pass_through_transform(kargs.seqlen_q)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto qt_dram = [&]() { + if constexpr(FmhaPipeline::kQTLoadOnce) + { + return pad_tensor_view( + qt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + qt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + const auto k_dram = [&]() { + if constexpr(FmhaPipeline::kKLoadOnce) + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto kt_dram_naive = + transform_tensor_view(k_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_q), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto kt_dram = [&]() { + if constexpr(FmhaPipeline::kKTLoadOnce) + { + return pad_tensor_view( + kt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + kt_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto v_dram = [&]() { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kVLoadOnce) + { + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view_packed( + lse_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + lse_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto d_dram = [&]() { + const auto d_dram_naive = make_naive_tensor_view_packed( + d_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + d_dram_naive, make_tuple(number{}), sequence{}); + }(); + + const auto do_dram_naive = make_naive_tensor_view( + do_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_do, 1), + number{}, + number<1>{}); + const auto do_dram = [&]() { + if constexpr(FmhaPipeline::kOGradLoadOnce) + { + return pad_tensor_view( + do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + const auto dot_dram_naive = + transform_tensor_view(do_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_q)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + const auto dot_dram = [&]() { + if constexpr(FmhaPipeline::kOGradTLoadOnce) + { + return pad_tensor_view( + dot_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + dot_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto dq_dram = [&]() { + const auto dq_dram_naive = make_naive_tensor_view( + dq_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dq_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {0, 0}); + + auto qt_dram_window = + make_tile_window(qt_dram, + [&]() { + if constexpr(FmhaPipeline::kQTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, + [&]() { + if constexpr(FmhaPipeline::kKLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_n0, 0}); + + auto kt_dram_window = + make_tile_window(kt_dram, + [&]() { + if constexpr(FmhaPipeline::kKTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, i_n0}); + + auto v_dram_window = make_tile_window( + v_dram, + [&]() { + if constexpr(FmhaPipeline::kVLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_n0, 0}); + + auto do_dram_window = make_tile_window( + do_dram, + [&]() { + if constexpr(FmhaPipeline::kOGradLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {0, 0}); + + auto dot_dram_window = + make_tile_window(dot_dram, + [&]() { + if constexpr(FmhaPipeline::kOGradTLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {0, 0}); + + auto dq_dram_window = make_tile_window( + dq_dram, + make_tuple(number{}, number{}), + {0, 0}); + + auto lse_dram_window = + make_tile_window(lse_dram, make_tuple(number{}), {0}); + + auto d_dram_window = make_tile_window(d_dram, make_tuple(number{}), {0}); + + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + auto dbias_dram_window = [&, i_nhead_ = i_nhead]() { + if constexpr(kHasBiasGrad) + { + BiasGradDataType* dbias_ptr = + reinterpret_cast(kargs.dbias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_dbias + + batch_offset_dbias; + + auto dbias_dram = [&]() { + const auto dbias_dram_naive = + make_naive_tensor_view( + dbias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_dbias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(dbias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + AccDataType slope = *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + // dropout + float rp_undrop = 1; + float scale_rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + uint64_t drop_seed = 0; + uint64_t drop_offset = 0; + bool is_store_randval = false; + + if constexpr(kHasDropout) + { + rp_undrop = kargs.rp_undrop; + scale_rp_undrop = kargs.scale_rp_undrop; + p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; + drop_seed = kargs.drop_seed; + drop_offset = kargs.drop_offset; + is_store_randval = kargs.is_store_randval; + } + BlockDropout dropout(i_batch, + i_nhead, + kargs.num_head_q, + drop_seed, + drop_offset, + rp_undrop, + p_undrop_in_uint8_t, + is_store_randval); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window, + qt_dram_window, + k_dram_window, + kt_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + do_dram_window, + dot_dram_window, + lse_dram_window, + d_dram_window, + dq_dram_window, + dbias_dram_window, + mask, + position_encoding, + kargs.raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + kargs.scale, +#endif + rp_undrop, + scale_rp_undrop, + smem_ptr, + dropout); + + auto dk_dram = [&]() { + const auto dk_dram_naive = make_naive_tensor_view( + dk_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_dk, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dk_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto dv_dram = [&]() { + const auto dv_dram_naive = make_naive_tensor_view( + dv_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_dv, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + dv_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto dk_dram_window = make_tile_window( + dk_dram, + make_tuple(number{}, number{}), + {i_n0, 0}); + + auto dv_dram_window = make_tile_window( + dv_dram, + make_tuple(number{}, number{}), + {i_n0, 0}); + + KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile); + VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile); + } +}; + +template +struct FmhaBwdOGradDotOKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaBwdOGradDotO = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu; + static constexpr ck_tile::index_t kM0 = kBlockSize; + static constexpr ck_tile::index_t kVHeaddim = FmhaBwdOGradDotO::kVHeaddim; + + using DDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using OGradDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaBwdOGradDotOCommonKargs + { + const void* o_ptr; + const void* do_ptr; + void* d_ptr; + + float p_undrop; + + ck_tile::index_t seqlen_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t stride_do; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_d; + ck_tile::index_t batch_stride_d; + }; + + struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs + { + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs + { + const int32_t* seqstart_q_ptr; + }; + + using Kargs = std:: + conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* o_ptr, + const void* do_ptr, + void* d_ptr, + float p_undrop, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t stride_do, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_d, + ck_tile::index_t batch_stride_do, + ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_d) + { + Kargs kargs{{o_ptr, + do_ptr, + d_ptr, + p_undrop, + seqlen_q, + hdim_v, + stride_do, + stride_o, + nhead_stride_do, + nhead_stride_o, + nhead_stride_d, + batch_stride_d}, + batch_stride_do, + batch_stride_o}; + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* o_ptr, + const void* do_ptr, + void* d_ptr, + float p_undrop, + const void* seqstart_q_ptr, + ck_tile::index_t hdim_v, + ck_tile::index_t stride_do, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_do, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_d, + ck_tile::index_t batch_stride_d) + { + Kargs kargs{{o_ptr, + do_ptr, + d_ptr, + p_undrop, + -1, // seqlen will be updated by another pointer + hdim_v, + stride_do, + stride_o, + nhead_stride_do, + nhead_stride_o, + nhead_stride_d, + batch_stride_d}, + reinterpret_cast(seqstart_q_ptr)}; + + return kargs; + } + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // divide problem + const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); + + long_index_t batch_offset_o = 0; + long_index_t batch_offset_do = 0; + long_index_t batch_offset_d = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + + batch_offset_o = query_start * kargs.stride_o; + batch_offset_do = query_start * kargs.stride_do; + batch_offset_d = static_cast(i_batch) * kargs.batch_stride_d; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + } + else + { + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + batch_offset_do = static_cast(i_batch) * kargs.batch_stride_do; + batch_offset_d = static_cast(i_batch) * kargs.batch_stride_d; + } + + // for simplicity, batch stride we just modify the pointer + const ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + const OGradDataType* do_ptr = reinterpret_cast(kargs.do_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_do + + batch_offset_do; + DDataType* d_ptr = reinterpret_cast(kargs.d_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_d + + batch_offset_d; + + // O/dO/D DRAM and DRAM window + const auto o_dram = [&]() { + auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + return pad_tensor_view(o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto do_dram = [&]() { + auto do_dram_naive = make_naive_tensor_view( + do_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_do, 1), + number{}, + number<1>{}); + return pad_tensor_view(do_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + auto d_dram = [&]() { + const auto d_dram_naive = make_naive_tensor_view_packed( + d_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + d_dram_naive, make_tuple(number{}), sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, make_tuple(number{}, number{}), {i_m0, 0}); + + auto do_dram_window = + make_tile_window(do_dram, make_tuple(number{}, number{}), {i_m0, 0}); + + auto d_dram_window = make_tile_window(d_dram, make_tuple(number{}), {i_m0}); + + FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bc875b8e5a3e7b9ead8a721c370375ed918cf48c --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaBwdTilePartitioner +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/) + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); + } +}; + +template +struct FmhaBwdOGradDotOTilePartitioner +{ + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/) + { + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0732fd2ce2f72bf19501046d9a42db6ae756a561..5ecc3a4d8022857fda8bdbbe99239b93cedd33fd 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1,18 +1,19 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include #include -// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] * K[seqlen_k, hdim_q] +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] -// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k]) -// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] namespace ck_tile { @@ -31,8 +32,11 @@ struct FmhaFwdKernel using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; using LSEDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; using VLayout = ck_tile::remove_cvref_t; @@ -41,8 +45,9 @@ struct FmhaFwdKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr bool kHasBias = FmhaPipeline::kHasBias; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -74,14 +79,15 @@ struct FmhaFwdKernel return n.empty() ? n : std::string("p") + n; }(); return _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + - (kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); #undef _SS_ #undef _TS_ // clang-format on @@ -108,6 +114,7 @@ struct FmhaFwdKernel ck_tile::index_t hdim_q; ck_tile::index_t hdim_v; + ck_tile::index_t num_head_q; // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k // if this param is larger than 1, indicate MQA/GQA case ck_tile::index_t nhead_ratio_qk; @@ -136,6 +143,13 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_bias = 0; }; + struct FmhaFwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; @@ -153,19 +167,48 @@ struct FmhaFwdKernel { void* lse_ptr = nullptr; ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; }; - struct FmhaFwdBatchModeLSEKargs : FmhaFwdCommonLSEKargs + struct FmhaFwdCommonDropoutKargs { - ck_tile::index_t batch_stride_lse = 0; + void init_dropout(const float p_drop, + const std::tuple& drop_seed_offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + drop_seed = std::get<0>(drop_seed_offset); + drop_offset = std::get<1>(drop_seed_offset); + } + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + uint64_t drop_seed = 1; + uint64_t drop_offset = 0; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; }; struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, - std::conditional_t>, + std::conditional_t>>, std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -175,10 +218,15 @@ struct FmhaFwdKernel struct FmhaFwdGroupModeKargs : FmhaFwdCommonKargs, - std::conditional_t>, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -193,12 +241,14 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + void* rand_val_ptr, void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, @@ -207,22 +257,28 @@ struct FmhaFwdKernel ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -232,6 +288,7 @@ struct FmhaFwdKernel seqlen_k, hdim_q, hdim_v, + num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), @@ -250,18 +307,24 @@ struct FmhaFwdKernel {}, // placeholder for mask {}, // placeholder for lse {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_o}; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; kargs.batch_stride_bias = batch_stride_bias; } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } if constexpr(kHasMask) { kargs.window_size_left = window_size_left; @@ -279,6 +342,15 @@ struct FmhaFwdKernel kargs.scale_p = scale_p; kargs.scale_o = scale_o; } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } return kargs; } @@ -289,6 +361,7 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + void* rand_val_ptr, void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, @@ -296,6 +369,7 @@ struct FmhaFwdKernel const void* seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, @@ -304,16 +378,22 @@ struct FmhaFwdKernel ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) { Kargs kargs{{q_ptr, k_ptr, @@ -323,6 +403,7 @@ struct FmhaFwdKernel -1, // hdim_q, hdim_v, + num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), @@ -341,16 +422,22 @@ struct FmhaFwdKernel {}, // placeholder for mask {}, // placeholder for lse {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } if constexpr(kHasMask) { kargs.window_size_left = window_size_left; @@ -361,12 +448,21 @@ struct FmhaFwdKernel { kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; } if constexpr(kDoFp8StaticQuant) { kargs.scale_p = scale_p; kargs.scale_o = scale_o; } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } return kargs; } @@ -398,12 +494,13 @@ struct FmhaFwdKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; if constexpr(kIsGroupMode) { @@ -421,17 +518,17 @@ struct FmhaFwdKernel { batch_offset_v = key_start; } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias + key_start; } - else + if constexpr(kStoreLSE) { - batch_offset_bias = key_start; + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } - if constexpr(kStoreLSE) + if constexpr(kHasDropout) { - batch_offset_lse = query_start; + batch_offset_randval = query_start * kargs.stride_randval; } batch_offset_o = query_start * kargs.stride_o; @@ -461,7 +558,7 @@ struct FmhaFwdKernel batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } @@ -469,6 +566,11 @@ struct FmhaFwdKernel { batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } @@ -585,7 +687,7 @@ struct FmhaFwdKernel const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { constexpr auto bias_dram_window_lengths = make_tuple(number{}, number{}); - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { const BiasDataType* bias_ptr = reinterpret_cast(kargs.bias_ptr) + @@ -642,6 +744,56 @@ struct FmhaFwdKernel } }(); + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head_q, + kargs.drop_seed, + kargs.drop_offset, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( @@ -654,6 +806,39 @@ struct FmhaFwdKernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + auto o_acc_tile = [&]() { if constexpr(kDoFp8StaticQuant) { @@ -666,14 +851,17 @@ struct FmhaFwdKernel identity{}, // v_element_func bias_dram_window, identity{}, // bias_element_func + randval_dram_window, lse_dram_window, identity{}, // lse_element_func identity{}, // s_acc_element_func scales{kargs.scale_p}, // p_compute_element_func composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func mask, + position_encoding, kargs.scale_s, - smem_ptr); + smem_ptr, + dropout); } else { @@ -681,10 +869,13 @@ struct FmhaFwdKernel k_dram_window, v_dram_window, bias_dram_window, + randval_dram_window, lse_dram_window, mask, + position_encoding, kargs.scale_s, - smem_ptr); + smem_ptr, + dropout); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f4313d5b690db53b4fd78a4375ece2e88922a30 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -0,0 +1,455 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +template +struct FmhaFwdSplitKVCombineKernel +{ + using TilePartitioner = remove_cvref_t; + using FmhaPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using LSEDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + __host__ static std::string GetName() + { + // sync with generate.py + // clang-format off + + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(FmhaPipeline::kM0) + "x" + + _TS_(FmhaPipeline::kN1) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + + _SS_(FmhaPipeline::name) + + (pn.empty() ? "" : "_" + pn) + + (kStoreLSE ? "_lse" : "" ) + + (kDoFp8StaticQuant ? "_squant" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct EmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct CommonKargs + { + const void* lse_acc_ptr; + const void* o_acc_ptr; + void* o_ptr; + + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + + ck_tile::index_t seqlen_q; + ck_tile::index_t hdim_v; + ck_tile::index_t num_splits; + + ck_tile::index_t row_stride_o_acc; + ck_tile::index_t row_stride_o; + + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + ck_tile::index_t nhead_stride_o; + + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + }; + + struct CommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; + }; + + struct Fp8StaticQuantKargs + { + float scale_o; + }; + + struct BatchModeKargs + : CommonKargs, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_o; + }; + + struct GroupModeKargs + : CommonKargs, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + }; + + using Kargs = std::conditional_t; + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* lse_acc_ptr, + const void* o_acc_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits, + float scale_o, + ck_tile::index_t row_stride_o_acc, + ck_tile::index_t row_stride_o, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc) + { + Kargs kargs{{lse_acc_ptr, + o_acc_ptr, + o_ptr, + batch, + max_seqlen_q, + seqlen_q, + hdim_v, + num_splits, + row_stride_o_acc, + row_stride_o, + nhead_stride_lse_acc, + nhead_stride_o_acc, + nhead_stride_o, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + batch_stride_o}; + + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_o = scale_o; + } + + return kargs; + } + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* lse_acc_ptr, + const void* o_acc_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + const void* seqstart_q_ptr, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits, + float scale_o, + ck_tile::index_t row_stride_o_acc, + ck_tile::index_t row_stride_o, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc) + { + Kargs kargs{{lse_acc_ptr, + o_acc_ptr, + o_ptr, + batch, + max_seqlen_q, + -1, // seqlen will be updated by another pointer + hdim_v, + num_splits, + row_stride_o_acc, + row_stride_o, + nhead_stride_lse_acc, + nhead_stride_o_acc, + nhead_stride_o, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + reinterpret_cast(seqstart_q_ptr)}; + + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_o = scale_o; + } + + return kargs; + } + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); + } + + __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = + TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + const long_index_t batch_offset_lse_acc = + static_cast(i_batch) * kargs.batch_stride_lse_acc; + const long_index_t batch_offset_o_acc = + static_cast(i_batch) * kargs.batch_stride_o_acc; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + + batch_offset_o = query_start * kargs.row_stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + } + else + { + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + // for simplicity, batch stride we just modify the pointer + const LSEDataType* lse_acc_ptr = + reinterpret_cast(kargs.lse_acc_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lse_acc + batch_offset_lse_acc; + const OaccDataType* o_acc_ptr = + reinterpret_cast(kargs.o_acc_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // LSEacc/Oacc DRAM and DRAM windows + const auto lse_acc_dram = [&]() { + const auto lse_acc_dram_naive = make_naive_tensor_view( + lse_acc_ptr, + make_tuple(kargs.num_splits, kargs.seqlen_q), + make_tuple(kargs.split_stride_lse_acc, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + lse_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_acc_dram = [&]() { + const auto o_acc_dram_naive = make_naive_tensor_view( + o_acc_ptr, + make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v), + make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1), + number{}, + number<1>{}); + + auto o_acc_dram_view = pad_tensor_view( + o_acc_dram_naive, + make_tuple(number<1>{}, number{}, number{}), + sequence{}); + + const index_t padded_max_seqlen_q = + o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; + const index_t padded_hdim_v = + o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; + + return transform_tensor_view( + o_acc_dram_view, + make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)), + make_pass_through_transform(padded_hdim_v)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + }(); + + auto lse_acc_dram_window = make_tile_window( + lse_acc_dram, + [&]() { + return make_tuple(number{}, number{}); + }(), + {0, i_m0}); + + auto o_acc_dram_window = make_tile_window( + o_acc_dram, + [&]() { + return make_tuple(number{}, number{}); + }(), + {i_m0, i_n1}); + + // LSE DRAM window + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + lse_acc_dram_window, + o_acc_dram_window, + lse_dram_window, + identity{}, // lse_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + kargs.num_splits, + kargs.max_seqlen_q, + smem_ptr); + } + else + { + return FmhaPipeline{}(lse_acc_dram_window, + o_acc_dram_window, + lse_dram_window, + kargs.num_splits, + kargs.max_seqlen_q, + smem_ptr); + } + }(); + + // O DRAM and DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.row_stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9f04843a391dd50630a99a44769884448793e4d2 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaFwdSplitKVCombineTilePartitioner +{ + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN1 = kN1_; + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * + ck_tile::integer_divide_ceil(hdim_v_, kN1), + nhead_, + batch_size_); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp new file mode 100644 index 0000000000000000000000000000000000000000..45ed185adaa454dd2be368f4553294ae7814dff5 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -0,0 +1,913 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct FmhaFwdSplitKVKernel +{ + using TilePartitioner = ck_tile::remove_cvref_t; + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + using OaccDataType = remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + __host__ static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using gbr = typename bfs::Gemm0BlockWarps; + using gwt = typename bfs::Gemm0WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct EmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct CommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* lse_acc_ptr; + void* o_acc_ptr; + + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + ck_tile::index_t num_splits; + + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o_acc; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + }; + + struct CommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct BatchModeBiasKargs : CommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct AlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct MaskKargs + { + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct Fp8StaticQuantKargs + { + float scale_p; + }; + + struct CommonDropoutKargs + { + void init_dropout(const float p_drop, + const std::tuple& drop_seed_offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + drop_seed = std::get<0>(drop_seed_offset); + drop_offset = std::get<1>(drop_seed_offset); + } + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + uint64_t drop_seed = 1; + uint64_t drop_offset = 0; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + struct BatchModeDropoutKargs : CommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct BatchModeKargs + : CommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + }; + + struct GroupModeKargs + : CommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_acc_ptr, + void* o_acc_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + ck_tile::index_t num_splits, + float scale_s, + float scale_p, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o_acc, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_acc_ptr, + o_acc_ptr, + batch, + max_seqlen_q, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + num_splits, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o_acc, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_lse_acc, + nhead_stride_o_acc, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + batch_stride_q, + batch_stride_k, + batch_stride_v}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + template + __host__ static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_acc_ptr, + void* o_acc_ptr, + ck_tile::index_t batch, + ck_tile::index_t max_seqlen_q, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + ck_tile::index_t num_splits, + float scale_s, + float scale_p, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o_acc, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse_acc, + ck_tile::index_t nhead_stride_o_acc, + ck_tile::index_t batch_stride_lse_acc, + ck_tile::index_t batch_stride_o_acc, + ck_tile::index_t split_stride_lse_acc, + ck_tile::index_t split_stride_o_acc, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lse_acc_ptr, + o_acc_ptr, + batch, + max_seqlen_q, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + num_splits, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o_acc, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_lse_acc, + nhead_stride_o_acc, + batch_stride_lse_acc, + batch_stride_o_acc, + split_stride_lse_acc, + split_stride_o_acc}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + } + if constexpr(kHasDropout) + { + kargs.init_dropout(p_drop, drop_seed_offset); + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits) + { + return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits); + } + + __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = + TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v, kargs.num_splits); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + const long_index_t batch_offset_lse_acc = + static_cast(i_batch) * kargs.batch_stride_lse_acc; + const long_index_t batch_offset_o_acc = + static_cast(i_batch) * kargs.batch_stride_o_acc; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias + key_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + OaccDataType* o_acc_ptr = reinterpret_cast(kargs.o_acc_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o_acc + + batch_offset_o_acc + i_split * kargs.split_stride_o_acc; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse acc + auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() { + constexpr auto lse_acc_dram_window_lengths = make_tuple(number{}); + LSEDataType* lse_acc_ptr = + reinterpret_cast(kargs.lse_acc_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse_acc + + batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc; + + const auto lse_acc_dram = [&]() { + const auto lse_acc_dram_naive = + make_naive_tensor_view(lse_acc_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0}); + }(); + + // dropout + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + uint64_t drop_seed = 0; + uint64_t drop_offset = 0; + bool is_store_randval = false; + + if constexpr(kHasDropout) + { + rp_undrop = kargs.rp_undrop; + p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t; + drop_seed = kargs.drop_seed; + drop_offset = kargs.drop_offset; + is_store_randval = kargs.is_store_randval; + } + BlockDropout dropout(i_batch, + i_nhead, + kargs.num_head_q, + drop_seed, + drop_offset, + rp_undrop, + p_undrop_in_uint8_t, + is_store_randval); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + auto o_acc_tile = [&, i_split_ = i_split]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_acc_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + identity{}, // o_acc_element_func + kargs.num_splits, + i_split_, + mask, + position_encoding, + kargs.scale_s, + smem_ptr, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_acc_dram_window, + kargs.num_splits, + i_split_, + mask, + position_encoding, + kargs.scale_s, + smem_ptr, + dropout); + } + }(); + + // Oacc DRAM and Oacc DRAM window + auto o_acc_dram = [&]() { + const auto o_acc_dram_naive = make_naive_tensor_view( + o_acc_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.hdim_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_acc_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_acc_dram_window = + make_tile_window(o_acc_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_acc_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp new file mode 100644 index 0000000000000000000000000000000000000000..aec37cb36f699ae560cbe598ccf674569f80b29b --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct FmhaFwdSplitKVTilePartitioner +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; + static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; + static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + + __host__ static constexpr auto GridSize(ck_tile::index_t batch_size, + ck_tile::index_t nhead, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_splits) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) * + ck_tile::integer_divide_ceil(hdim_v, kN1), + nhead * num_splits, + batch_size); + } + + CK_TILE_DEVICE auto + operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v, ck_tile::index_t num_splits) + { + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(blockIdx.x, num_tile_n1); + const auto [i_nhead, i_split] = f(blockIdx.y, num_splits); + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp index 52f458c72e5a2df58a57c3903904c76b73880aa1..2dca84b78645103fbbdcd30b5417381a8954ab8b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; - __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, - ck_tile::index_t nhead_, - ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_) + static constexpr const char* name = "shb"; + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) { // TODO: this may need tuning return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * @@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner } }; +template +using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner; + +template +struct FmhaFwdTilePartitioner_HBS +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; + static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; + static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + + static constexpr const char* name = "hbs"; + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, kM0) * + ck_tile::integer_divide_ceil(hdim_v_, kN1)); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f1899370387c8738a902a058650b60aef589175b --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdOGradDotO +{ + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kVHeaddim = Problem::kVHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } + + template + CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + DDramBlockWindowTmp& d_dram_block_window_tmp, + float p_undrop) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kBlockSize == + OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kBlockSize == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + + auto o_dram_window = + make_tile_window(o_dram_block_window_tmp.get_bottom_tensor_view(), + o_dram_block_window_tmp.get_window_lengths(), + o_dram_block_window_tmp.get_window_origin(), + Policy::template MakePreODramTileDistribution()); + + auto o = load_tile(o_dram_window); + + auto do_dram_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + do_dram_block_window_tmp.get_window_origin(), + Policy::template MakePreOGradDramTileDistribution()); + + auto do_ = load_tile(do_dram_window); + + // declare d + constexpr auto d_dstr = + make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( + o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{})); + + auto d = make_static_distributed_tensor(d_dstr); + + clear_tile(d); // Initialize D + + constexpr auto o_spans = decltype(o)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + d(i_idx) += + (type_convert(o[i_j_idx]) * type_convert(do_[i_j_idx])); + }); + }); + + tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d); + + store_tile(d_dram_block_window_tmp, d); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7843ab33a1f16a15babeb220d5ee4217d994bad6 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// These templates are not used here. +using BlockFmhaBwdOGradDotODefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3444567508a5ad3a3a3e97cf86140d5be09400f1 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp @@ -0,0 +1,848 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineKSKTSVR +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = false; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = true; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = false; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "ks_kts_vr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& qt_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& kt_dram_block_window_tmp, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kVHeaddim == + OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in LDS + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS + QDataType* qt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto qt_lds = make_tensor_view( + qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS + KDataType* kt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto kt_lds = make_tensor_view( + kt_lds_ptr, Policy::template MakeKTLdsBlockDescriptor()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + OGradDataType* dot_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto dot_lds = make_tensor_view( + dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeKT())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // // persistent K in LDS + + auto kt_dram_block_window = kt_dram_block_window_tmp; + + auto kt_dram_window = make_tile_window( + kt_dram_block_window.get_bottom_tensor_view(), + kt_dram_block_window.get_window_lengths(), + kt_dram_block_window.get_window_origin(), + Policy::template MakeKTDramTileDistribution()); // K^T DRAM tile window for + // load + + auto kt_block_tile = load_tile(kt_dram_window); + + auto kt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledKTRegBlockDescriptor()); + shuffle_tile(kt_shuffle_tmp, kt_block_tile); + + store_tile(kt_lds_window, kt_shuffle_tmp); // persistent K^T in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto qt_dram_block_window = + make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), + qt_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dot_dram_block_window = + make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), + dot_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto qt_dram_window = + make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), + qt_dram_block_window.get_window_lengths(), + qt_dram_block_window.get_window_origin(), + Policy::template MakeQTDramTileDistribution()); + + auto dot_dram_window = + make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), + dot_dram_block_window.get_window_lengths(), + dot_dram_block_window.get_window_origin(), + Policy::template MakeOGradTDramTileDistribution()); + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template MakeBiasTTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + { + move_tile_window(q_dram_window, {0, kK0}); + + clear_tile(st_acc); // Initialize S^T + + store_tile(q_lds_window, q_block_tile); // LDS write 0 + q_block_tile = load_tile(q_dram_window); // global read 1 + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + move_tile_window(q_dram_window, {0, kK0}); + + store_tile(q_lds_window, + q_block_tile); // LDS write i + 1 + q_block_tile = load_tile(q_dram_window); // global read i + 2 + }); + } + + const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{})); + block_sync_lds(); + + store_tile(q_lds_window, q_block_tile); + block_sync_lds(); + + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + block_sync_lds(); + auto bias_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + st_acc(i_j_idx) *= raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + auto dot_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledOGradTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(dot_shuffle_tmp, dot_prefetch); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + } + move_tile_window(dot_dram_window, {0, kK1}); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto dot = load_tile(dot_dram_window); // load next OGrad^T + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile(pt_gemm, + sequence{}, + sequence<(i_k1 + 1) * kK1, kN0>{}), + dot_lds_window); + block_sync_lds(); + shuffle_tile(dot_shuffle_tmp, dot); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + + move_tile_window(dot_dram_window, {0, kK1}); + }); + } + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + // tail + { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence{}), + dot_lds_window); + block_sync_lds(); + } + + // STAGE 4, OGrad@V Gemm2 + auto dpt_acc = SPGradTBlockTileType{}; + + { + move_tile_window(do_dram_window, {0, kK2}); + + clear_tile(dpt_acc); // Initialize PGrad^T + + store_tile(do_lds_window, do_block_tile); // LDS write 0 + do_block_tile = load_tile(do_dram_window); // global read 1 + } + + if constexpr(k2_loops > 2) + { + static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + move_tile_window(do_dram_window, {0, kK2}); + + store_tile(do_lds_window, + do_block_tile); // LDS write i + 1 + do_block_tile = load_tile(do_dram_window); // global read i + 2 + }); + } + + const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile + { // tail + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 2) * kK2>{}, + sequence{})); + block_sync_lds(); + + store_tile(do_lds_window, do_block_tile); + block_sync_lds(); + + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 1) * kK2>{}, + sequence{})); + } + + // STAGE 5, P^T(PGrad^T - D) + const auto d = load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + auto qt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledQTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(qt_shuffle_tmp, qt_prefetch); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + } + move_tile_window(qt_dram_window, {0, kK3}); + + const auto dst_gemm = cast_tile(dst); + + if constexpr(k3_loops > 1) + { + static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) { + const auto qt = load_tile(qt_dram_window); // load next Q^T + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile(dst_gemm, + sequence{}, + sequence<(i_k3 + 1) * kK3, kN0>{}), + qt_lds_window); + block_sync_lds(); + shuffle_tile(qt_shuffle_tmp, qt); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + + move_tile_window(qt_dram_window, {0, kK3}); + }); + } + // tail + { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence{}), + qt_lds_window); + block_sync_lds(); + } + + // STAGE 7, SGrad@K^T Gemm4 + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a05fbf252fdee0c189b34e0eaa679ca6ce3f01bb --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// This pipeline is v located in regs, k & k^t located in lds. +using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dec421c1e6ee6e803733e8ca559805ff34129797 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp @@ -0,0 +1,821 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineKSVR +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = false; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = false; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = false; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "ks_vr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& qt_dram_block_window_tmp, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kVHeaddim == + OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in LDS + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS + QDataType* qt_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto qt_lds = make_tensor_view( + qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS + auto kt_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptorAsKT()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + OGradDataType* dot_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto dot_lds = make_tensor_view( + dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // // persistent K in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto qt_dram_block_window = + make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(), + qt_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dot_dram_block_window = + make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(), + dot_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_q_start}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto qt_dram_window = + make_tile_window(qt_dram_block_window.get_bottom_tensor_view(), + qt_dram_block_window.get_window_lengths(), + qt_dram_block_window.get_window_origin(), + Policy::template MakeQTDramTileDistribution()); + + auto dot_dram_window = + make_tile_window(dot_dram_block_window.get_bottom_tensor_view(), + dot_dram_block_window.get_window_lengths(), + dot_dram_block_window.get_window_origin(), + Policy::template MakeOGradTDramTileDistribution()); + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template MakeBiasTTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + { + move_tile_window(q_dram_window, {0, kK0}); + + clear_tile(st_acc); // Initialize S^T + + store_tile(q_lds_window, q_block_tile); // LDS write 0 + q_block_tile = load_tile(q_dram_window); // global read 1 + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + move_tile_window(q_dram_window, {0, kK0}); + + store_tile(q_lds_window, + q_block_tile); // LDS write i + 1 + q_block_tile = load_tile(q_dram_window); // global read i + 2 + }); + } + + const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{})); + block_sync_lds(); + + store_tile(q_lds_window, q_block_tile); + block_sync_lds(); + + gemm_0(st_acc, + q_lds_window, + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + block_sync_lds(); + auto bias_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + st_acc(i_j_idx) *= raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + auto dot_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledOGradTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(dot_shuffle_tmp, dot_prefetch); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + } + move_tile_window(dot_dram_window, {0, kK1}); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto dot = load_tile(dot_dram_window); // load next OGrad^T + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile(pt_gemm, + sequence{}, + sequence<(i_k1 + 1) * kK1, kN0>{}), + dot_lds_window); + block_sync_lds(); + shuffle_tile(dot_shuffle_tmp, dot); + store_tile(dot_lds_window, + dot_shuffle_tmp); // store the prefetch + + move_tile_window(dot_dram_window, {0, kK1}); + }); + } + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + // tail + { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence{}), + dot_lds_window); + block_sync_lds(); + } + + // STAGE 4, OGrad@V Gemm2 + auto dpt_acc = SPGradTBlockTileType{}; + + { + move_tile_window(do_dram_window, {0, kK2}); + + clear_tile(dpt_acc); // Initialize PGrad^T + + store_tile(do_lds_window, do_block_tile); // LDS write 0 + do_block_tile = load_tile(do_dram_window); // global read 1 + } + + if constexpr(k2_loops > 2) + { + static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + move_tile_window(do_dram_window, {0, kK2}); + + store_tile(do_lds_window, + do_block_tile); // LDS write i + 1 + do_block_tile = load_tile(do_dram_window); // global read i + 2 + }); + } + + const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile + { // tail + block_sync_lds(); + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 2) * kK2>{}, + sequence{})); + block_sync_lds(); + + store_tile(do_lds_window, do_block_tile); + block_sync_lds(); + + gemm_2(dpt_acc, + do_lds_window, + get_slice_tile(v, + sequence<0, (k2_loops - 1) * kK2>{}, + sequence{})); + } + + // STAGE 5, P^T(PGrad^T - D) + const auto d = load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + auto qt_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledQTRegBlockDescriptor()); + block_sync_lds(); + { + shuffle_tile(qt_shuffle_tmp, qt_prefetch); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + } + move_tile_window(qt_dram_window, {0, kK3}); + + const auto dst_gemm = cast_tile(dst); + + if constexpr(k3_loops > 1) + { + static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) { + const auto qt = load_tile(qt_dram_window); // load next Q^T + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile(dst_gemm, + sequence{}, + sequence<(i_k3 + 1) * kK3, kN0>{}), + qt_lds_window); + block_sync_lds(); + shuffle_tile(qt_shuffle_tmp, qt); + store_tile(qt_lds_window, + qt_shuffle_tmp); // store the prefetch + + move_tile_window(qt_dram_window, {0, kK3}); + }); + } + // tail + { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence{}), + qt_lds_window); + block_sync_lds(); + } + + // STAGE 7, SGrad@K^T Gemm4 + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cc4e6304d0fd8ad0b6714e8fddf841fc2f5452b2 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// This pipeline is v located in regs, k located in lds. +using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp new file mode 100644 index 0000000000000000000000000000000000000000..97487bb71e22ddf50dd08b8bda3c55a20ba1859b --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp @@ -0,0 +1,692 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + + static constexpr bool kQLoadOnce = true; + static constexpr bool kQTLoadOnce = false; + static constexpr bool kKLoadOnce = true; + static constexpr bool kKTLoadOnce = false; + static constexpr bool kVLoadOnce = true; + static constexpr bool kOGradLoadOnce = true; + static constexpr bool kOGradTLoadOnce = false; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentOGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + static constexpr index_t kAlignmentQGrad = + kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad(); + static constexpr index_t kAlignmentKGrad = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + static constexpr index_t kAlignmentVGrad = + kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias(); + + static constexpr const char* name = "qs_ks_vr_dos"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, + const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/, + const KDramBlockWindowTmp& k_dram_block_window_tmp, + const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, + const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, + FmhaMask mask, + PositionEncoding position_encoding, + float raw_scale, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + float scale, +#endif + float rp_undrop, + float scale_rp_undrop, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // Q tile in LDS + QDataType* q_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK())); + auto q_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); + auto q_lds_window = + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + + // QT tile in LDS + auto qt_lds = make_tensor_view( + q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT()); + auto qt_lds_window = + make_tile_window(qt_lds, make_tuple(number{}, number{}), {0, 0}); + + // K tile in LDS + auto k_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // KT tile in LDS + auto kt_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeKLdsBlockDescriptorAsKT()); + auto kt_lds_window = + make_tile_window(kt_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGrad tile in LDS + OGradDataType* do_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ())); + auto do_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); + auto do_lds_window = + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + + // OGradT tile in LDS + auto dot_lds = make_tensor_view( + do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT()); + auto dot_lds_window = + make_tile_window(dot_lds, make_tuple(number{}, number{}), {0, 0}); + + // SGrad tile in LDS + GemmDataType* ds_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad())); + auto ds_lds = make_tensor_view( + ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor()); + auto ds_lds_window = + make_tile_window(ds_lds, make_tuple(number{}, number{}), {0, 0}); + + // BiasT/BiasGradT tile in LDS, use the same size and layout + BiasDataType* biast_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeK() + + Policy::template GetSmemSizeQ() + + Policy::template GetSmemSizeOGrad())); + auto biast_lds = make_tensor_view( + biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor()); + auto biast_lds_shuffle_window = + make_tile_window(biast_lds, make_tuple(number{}, number{}), {0, 0}); + auto dbiast_lds_shuffle_window = + make_tile_window(biast_lds, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeShuffledBiasTileDistribution()); + + static_assert(std::is_same_v, + "BiasDataType and BiasGradDataType should be the same!"); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); + constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm(); + constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm(); + constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm(); + + auto v_dram_window = make_tile_window( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + v_dram_block_window_tmp.get_window_origin(), + Policy::template MakeVInRegDramTileDistribution()); + + auto v = load_tile(v_dram_window); // persistent V register tile + + using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile()); + using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile()); + using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile()); + + // init VGrad & KGrad + auto dv_acc = decltype(gemm_1.MakeCBlockTile()){}; + auto dk_acc = decltype(gemm_3.MakeCBlockTile()){}; + + clear_tile(dv_acc); + clear_tile(dk_acc); + + auto k_dram_window = make_tile_window( + k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + k_dram_block_window_tmp.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + __builtin_amdgcn_sched_barrier(0); + const auto k_origin = k_dram_window.get_window_origin(); + const auto [seqlen_q_start, seqlen_q_end] = + mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking) + { + if(num_total_loop <= 0) + { + // Note: here dk_acc&dv_acc are all cleard, return it + // Note: v loaded but no fence, ignore it. + return ck_tile::make_tuple(dk_acc, dv_acc); + } + } + + auto k_block_tile = load_tile(k_dram_window); + + store_tile(k_lds_window, k_block_tile); // // persistent K in LDS + + auto q_dram_block_window = + make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto do_dram_block_window = + make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(), + do_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto dq_dram_block_window = + make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(), + dq_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, 0}); + + auto lse_dram_block_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + auto d_dram_block_window = + make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(), + d_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_block_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, bias_origin.at(number<1>{})}); // M/N + + const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin(); + auto dbias_dram_block_window = + make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(), + dbias_dram_block_window_tmp.get_window_lengths(), + {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N + + auto lse_dram_window = make_tile_window( + lse_dram_block_window.get_bottom_tensor_view(), + lse_dram_block_window.get_window_lengths(), + lse_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto d_dram_window = make_tile_window( + d_dram_block_window.get_bottom_tensor_view(), + d_dram_block_window.get_window_lengths(), + d_dram_block_window.get_window_origin(), + Policy::template MakeLSEDDramTileDistribution()); + + auto bias_dram_window = + make_tile_window(bias_dram_block_window.get_bottom_tensor_view(), + bias_dram_block_window.get_window_lengths(), + bias_dram_block_window.get_window_origin(), + Policy::template MakeBiasTileDistribution()); + + auto biast_lds_window = + make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(), + biast_lds_shuffle_window.get_window_lengths(), + biast_lds_shuffle_window.get_window_origin(), + Policy::template MakeBiasTTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_q_start); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kM0 / kK1; + constexpr index_t k2_loops = kVHeaddim / kK2; + constexpr index_t k3_loops = kM0 / kK3; + constexpr index_t k4_loops = kN0 / kK4; + do + { + auto q_dram_window = make_tile_window( + q_dram_block_window.get_bottom_tensor_view(), + q_dram_block_window.get_window_lengths(), + q_dram_block_window.get_window_origin(), + Policy::template MakeQDramTileDistribution()); // Q DRAM tile window for + // load + + auto do_dram_window = make_tile_window( + do_dram_block_window.get_bottom_tensor_view(), + do_dram_block_window.get_window_lengths(), + do_dram_block_window.get_window_origin(), + Policy::template MakeOGradDramTileDistribution()); // OGrad DRAM tile + // window for load + + // STAGE 1, Q@K Gemm0 + auto st_acc = SPTBlockTileType{}; + + auto q_block_tile = load_tile(q_dram_window); + clear_tile(st_acc); // Initialize S^T + store_tile(q_lds_window, q_block_tile); // LDS write + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(st_acc, + get_slice_tile(q_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{}), + get_slice_tile(k_lds_window, + sequence<0, i_k0 * kK0>{}, + sequence{})); + block_sync_lds(); + }); + } + + auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile + { // tail + block_sync_lds(); + gemm_0(st_acc, + get_slice_tile(q_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(k_lds_window, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + block_sync_lds(); + } + + // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + block_sync_lds(); + auto bias_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBiasTileDistribution()); + shuffle_tile(bias_shuffle_tmp, bias_tile); + store_tile(biast_lds_shuffle_window, bias_shuffle_tmp); + block_sync_lds(); + auto biast_tile = load_tile(biast_lds_window); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x = raw_scale * x + type_convert(y); +#else + x = scale * x + log2e_v * type_convert(y); +#endif + }, + st_acc, + biast_tile); + move_tile_window(bias_dram_window, {kM0, 0}); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + st_acc(i_j_idx) *= raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc); +#endif + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if(st_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto lse = load_tile(lse_dram_window); + + static const auto get_validated_lse = [](LSEDataType raw_lse) { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_lse == -numeric::infinity() + ? type_convert(0.f) + : raw_lse; + } + else + { + return raw_lse; + } + }; + + auto pt = SPTBlockTileType{}; + constexpr auto pt_spans = decltype(pt)::get_distributed_spans(); + sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_lse = log2e_v * get_validated_lse(lse[i_idx]); +#endif + sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); + } + else + { + pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse); + } +#else + pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx])); +#endif + }); + }); + + if constexpr(kHasDropout) + { + dropout.Run( + seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window); + } + + // STAGE 3, P^T@OGrad^T Gemm1 + block_sync_lds(); + store_tile(do_lds_window, do_block_tile); // store the prefetch + + const auto pt_gemm = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [](const auto& x) { return type_convert(x > 0.f ? x : 0.f); }, + pt); + } + else + { + return cast_tile(pt); + } + }(); + + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + block_sync_lds(); + gemm_1(dv_acc, + get_slice_tile( + pt_gemm, sequence{}, sequence<(i_k1 + 1) * kK1, kN0>{}), + get_slice_tile(dot_lds_window, + sequence<0, i_k1 * kK1>{}, + sequence{})); + block_sync_lds(); + }); + + // STAGE 4, OGrad@V Gemm2 + auto dpt_acc = SPGradTBlockTileType{}; + clear_tile(dpt_acc); // Initialize PGrad^T + + static_for<0, k2_loops, 1>{}([&](auto i_k2) { + block_sync_lds(); + gemm_2(dpt_acc, + get_slice_tile(do_lds_window, + sequence<0, i_k2 * kK2>{}, + sequence{}), + get_slice_tile( + v, sequence<0, i_k2 * kK2>{}, sequence{})); + block_sync_lds(); + }); + + // STAGE 5, P^T(PGrad^T - D) + const auto d = load_tile(d_dram_window); + + auto dst = SPGradTBlockTileType{}; + constexpr auto dst_spans = decltype(dst)::get_distributed_spans(); + sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + bool undrop_flag = pt[i_j_idx] >= 0; + dst(i_j_idx) = + pt[i_j_idx] * + (!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]); + }); + }); + + if constexpr(kHasBiasGrad) + { + const auto dbiast = [&]() { + if constexpr(kHasDropout) + { + return tile_elementwise_in( + [&rp_undrop](const auto& x) { + return type_convert(x * rp_undrop); + }, + dst); + } + else + { + return cast_tile(dst); + } + }(); + store_tile(biast_lds_shuffle_window, dbiast); + block_sync_lds(); + auto dbiast_tile = load_tile(dbiast_lds_shuffle_window); + auto dbiast_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeBiasTileDistribution()); + shuffle_tile(dbiast_shuffle_tmp, dbiast_tile); + store_tile(dbias_dram_block_window, dbiast_shuffle_tmp); + move_tile_window(dbias_dram_block_window, {kM0, 0}); + } + + // STAGE 6, SGrad^T@Q^T Gemm3 + block_sync_lds(); + const auto dst_gemm = cast_tile(dst); + + static_for<0, k3_loops, 1>{}([&](auto i_k3) { + block_sync_lds(); + gemm_3(dk_acc, + get_slice_tile( + dst_gemm, sequence{}, sequence<(i_k3 + 1) * kK3, kN0>{}), + get_slice_tile(qt_lds_window, + sequence<0, i_k3 * kK3>{}, + sequence{})); + block_sync_lds(); + }); + + // STAGE 7, SGrad@K^T Gemm4 + store_tile(ds_lds_window, dst_gemm); + + auto dq_acc = QGradBlockTileType{}; + clear_tile(dq_acc); // Initialize QGrad + + block_sync_lds(); + + static_for<0, k4_loops, 1>{}([&](auto i_k4) { + gemm_4(dq_acc, + get_slice_tile(ds_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{}), + get_slice_tile(kt_lds_window, + sequence<0, i_k4 * kK4>{}, + sequence{})); + }); + + // QGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dq_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc); + } + const auto dq = cast_tile(dq_acc); + update_tile(dq_dram_block_window, dq); + + // move tile windows + move_tile_window(q_dram_block_window, {kM0, 0}); + move_tile_window(dq_dram_block_window, {kM0, 0}); + move_tile_window(do_dram_block_window, {kM0, 0}); + move_tile_window(lse_dram_window, {kM0}); + move_tile_window(d_dram_window, {kM0}); + } while(++i_total_loops < num_total_loop); + + // KGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; }, + dk_acc); + } + else + { + tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc); + } + // VGrad Scale + if constexpr(kHasDropout) + { + tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc); + } + + return ck_tile::make_tuple(dk_acc, dv_acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ac81990e0753359fee630794044fd08d93121058 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +// This pipeline is v located in regs, q & k & do located in lds. +using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy = + BlockFmhaBwdPipelineDefaultPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d867772a1f77692e0aecdd1bf948f09596c1f214 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -0,0 +1,1343 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdPipelineDefaultPolicy +{ + static constexpr bool QLoadOnce = + QLoadOnce_; // if q load whole block length (qkhdim) to LDS at once + static constexpr bool QTLoadOnce = + QTLoadOnce_; // if q^t load whole block length (qkhdim) to LDS at once + static constexpr bool KLoadOnce = + KLoadOnce_; // if k load whole block length (qkhdim) to LDS at once + static constexpr bool KTLoadOnce = + KTLoadOnce_; // if k^t load whole block length (qkhdim) to LDS at once + static constexpr bool VLoadOnce = + VLoadOnce_; // if v load whole block length (vhdim) to Vgprs at once + static constexpr bool OGradLoadOnce = + OGradLoadOnce_; // if do load whole block length (vhdim) to LDS at once + static constexpr bool OGradTLoadOnce = + OGradTLoadOnce_; // if do^t load whole block length (vhdim) to LDS at once + + // these are for global load + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + using QDataType = remove_cvref_t; + return 16 / sizeof(QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + if constexpr(VLoadOnce) + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; + } + else + { + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using ODataType = remove_cvref_t; + return 16 / sizeof(ODataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad() + { + using OGradDataType = remove_cvref_t; + return 16 / sizeof(OGradDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentKGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentVGrad() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto vec = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number{}); + return vec; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentK() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 4) + return 4; + else + return 2; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + + // TODO: not correct! + if constexpr(total_pixels > 32) + return 8; + else + return 4; + } + + // these are for lds + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() + { + // TODO: this is for 3d layout + using QDataType = remove_cvref_t; + return 16 / sizeof(QDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias() + { + // TODO: this is for 3d layout + using BiasDataType = remove_cvref_t; + return 16 / sizeof(BiasDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad() + { + // TODO: this is for 3d layout + using OGradDataType = remove_cvref_t; + return 16 / sizeof(OGradDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad() + { + // TODO: this is for 3d layout + using GemmDataType = remove_cvref_t; + return 16 / sizeof(GemmDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVInRegDramTileDistribution() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); + + return v_block_dstr; + } + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() + { + constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(MNPerBlock + 1) * KPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto x_lds_block_desc = transform_tensor_descriptor( + x_lds_block_desc_0, + make_tuple(make_pass_through_transform(MNPerBlock), + make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return x_lds_block_desc; + } + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptorAsXT() + { + constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(MNPerBlock + 1) * KPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto xt_lds_block_desc = transform_tensor_descriptor( + x_lds_block_desc_0, + make_tuple(make_pass_through_transform(MNPerBlock), + make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return xt_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() + { + static_assert(PixelsPerRow % KPack == 0); + constexpr index_t NPerRow = PixelsPerRow / KPack; + static_assert(MNPerBlock % NPerRow == 0); + static_assert(KPerBlock % KPack == 0); + + constexpr auto xt_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number<(MNPerBlock / NPerRow) * (PixelsPerRow + KPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto xt_lds_block_desc = transform_tensor_descriptor( + xt_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return xt_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackQ(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptorAsQT() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackQ(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackK(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptorAsKT() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + constexpr index_t kKPack = GetSmemKPackK(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPack = GetSmemKPackV(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + constexpr index_t kKPack = GetSmemKPackOGrad(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptorAsOGradT() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + constexpr index_t kKPack = GetSmemKPackOGrad(); + + return MakeXLdsBlockDescriptorAsXT(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPack = GetSmemKPackSGrad(); + + return MakeXLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsBlockDescriptor() + { + using QDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QDataType); + constexpr index_t kKPack = GetSmemKPackQ(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsBlockDescriptor() + { + using KDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType); + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor() + { + using OGradDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(OGradDataType); + constexpr index_t kKPack = GetSmemKPackOGrad(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + return MakeXTLdsBlockDescriptor(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor() + { + using BiasDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType); + constexpr index_t kKPack = GetSmemKPackBias(); + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kMPerBlock % kKPack == 0); + + constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto biast_lds_block_desc = transform_tensor_descriptor( + biast_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return biast_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() + { + constexpr index_t smem_size_q = sizeof(typename Problem::QDataType) * + MakeQLdsBlockDescriptor().get_element_space_size(); + return smem_size_q; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQT() + { + constexpr index_t smem_size_qt = [&]() { + if constexpr(QLoadOnce && !QTLoadOnce) + return 0; + else + return sizeof(typename Problem::QDataType) * + MakeQTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_qt; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() + { + constexpr index_t smem_size_k = sizeof(typename Problem::KDataType) * + MakeKLdsBlockDescriptor().get_element_space_size(); + return smem_size_k; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKT() + { + constexpr index_t smem_size_kt = [&]() { + if constexpr(KLoadOnce && !KTLoadOnce) + return 0; + else + return sizeof(typename Problem::KDataType) * + MakeKTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_kt; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + { + constexpr index_t smem_size_v = [&]() { + if constexpr(VLoadOnce) + return 0; + else + return sizeof(typename Problem::VDataType) * + MakeVLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_v; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGrad() + { + constexpr index_t smem_size_do = + sizeof(typename Problem::OGradDataType) * + MakeOGradLdsBlockDescriptor().get_element_space_size(); + return smem_size_do; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGradT() + { + constexpr index_t smem_size_dot = [&]() { + if constexpr(OGradLoadOnce && !OGradTLoadOnce) + return 0; + else + return sizeof(typename Problem::OGradDataType) * + MakeOGradTLdsBlockDescriptor().get_element_space_size(); + }(); + return smem_size_dot; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSGrad() + { + constexpr index_t smem_size_ds = + sizeof(typename Problem::GemmDataType) * + MakeSGradLdsBlockDescriptor().get_element_space_size(); + return smem_size_ds; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeBias() + { + constexpr index_t smem_size_bias = [&]() { + if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return sizeof(typename Problem::BiasDataType) * + MakeBiasTLdsBlockDescriptor().get_element_space_size(); + else + return 0; + }(); + return smem_size_bias; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + constexpr index_t smem_size_q = GetSmemSizeQ(); + constexpr index_t smem_size_qt = GetSmemSizeQT(); + constexpr index_t smem_size_k = GetSmemSizeK(); + constexpr index_t smem_size_kt = GetSmemSizeKT(); + constexpr index_t smem_size_v = GetSmemSizeV(); + constexpr index_t smem_size_do = GetSmemSizeOGrad(); + constexpr index_t smem_size_dot = GetSmemSizeOGradT(); + constexpr index_t smem_size_ds = GetSmemSizeSGrad(); + constexpr index_t smem_size_bias = GetSmemSizeBias(); + constexpr index_t smem_size_transpose = max(smem_size_ds, smem_size_bias); + + index_t smem_size = 0; + + if constexpr(QLoadOnce && OGradLoadOnce) + smem_size += smem_size_q + smem_size_qt + smem_size_do + smem_size_dot + + smem_size_transpose; // 1~4 & 10 + else if(QLoadOnce && !OGradLoadOnce && !OGradTLoadOnce) + smem_size += smem_size_q + smem_size_qt + + max(smem_size_do, + smem_size_dot, + smem_size_transpose); // 5/7/11 TODO: Multiple buffers strategy + else if(!QLoadOnce && !QTLoadOnce && OGradLoadOnce) + smem_size += smem_size_do + smem_size_dot + + max(smem_size_q, + smem_size_qt, + smem_size_transpose); // 6/8/12 TODO: Multiple buffers strategy + else if(!QLoadOnce && !QTLoadOnce && !OGradLoadOnce && !OGradTLoadOnce) + smem_size += max(smem_size_q, + smem_size_qt, + smem_size_do, + smem_size_dot, + smem_size_transpose); // 9/13 TODO: Multiple buffers strategy + + // 14/15 needs to be adjusted + if constexpr(KLoadOnce) + smem_size += (smem_size_k + smem_size_kt); // 1~13 + else + smem_size = + max(smem_size_k, smem_size_kt, smem_size); // 14/15 TODO: Multiple buffers strategy + + return max(smem_size, smem_size_v); // 15 + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution() + { + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + + constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane; + constexpr index_t N0 = NWarp; + + constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * 2; + constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane; + constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / 2; + constexpr index_t M1 = MWarp; + constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple>, + tuple, sequence<1, 0>>, + tuple, sequence<3, 1>>, + sequence<1, 1, 1>, + sequence<0, 2, 4>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using VDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t K1 = 16 / sizeof(VDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + + constexpr index_t K1 = GetAlignmentQ(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KLoadOnce) + return Problem::BlockFmhaShape::kQKHeaddim; + else + return Problem::BlockFmhaShape::kK0; + }(); + + constexpr index_t K1 = GetAlignmentK(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradLoadOnce) + return Problem::BlockFmhaShape::kVHeaddim; + else + return Problem::BlockFmhaShape::kK2; + }(); + + constexpr index_t K1 = GetAlignmentOGrad(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution() + { + constexpr index_t K1 = 16 / sizeof(DataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = 1; + constexpr index_t M1 = get_warp_size(); + constexpr index_t M0 = MPerBlock / M1; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1>>, + tuple, sequence<1>>, + sequence<1, 2, 2>, + sequence<2, 0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution() + { + using ODataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution() + { + using OGradDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kKPerBlock = Problem::kVHeaddim; + + return MakePreXDramTileDistribution(); + } + + template + CK_TILE_DEVICE static constexpr auto MakeQTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + + constexpr index_t N1 = GetTransposedAlignmentQ(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackQ(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(QTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK3; + }(); + + constexpr index_t N1 = GetTransposedAlignmentQ(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackQ(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + constexpr index_t N1 = GetTransposedAlignmentK(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(KTLoadOnce) + return Problem::BlockFmhaShape::kN0; + else + return Problem::BlockFmhaShape::kK4; + }(); + + constexpr index_t N1 = GetTransposedAlignmentK(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeOGradTDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + constexpr index_t N1 = GetTransposedAlignmentOGrad(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackOGrad(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradTRegBlockDescriptor() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = [&]() { + if constexpr(OGradTLoadOnce) + return Problem::BlockFmhaShape::kM0; + else + return Problem::BlockFmhaShape::kK1; + }(); + + constexpr index_t N1 = GetTransposedAlignmentOGrad(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackOGrad(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeBiasTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t N1 = GetTransposedAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t M3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackBias(); + static_assert(kKPack % M3 == 0); + constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave + constexpr index_t M1 = get_warp_size() / (M2 * N0); + constexpr index_t M0 = kBlockSize / get_warp_size(); + static_assert(kMPerBlock == M0 * M1 * M2 * M3); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 1>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<3, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t N1 = GetTransposedAlignmentBias(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); // TODO: this is not always true? + constexpr index_t M3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackBias(); + static_assert(kKPack % M3 == 0); + constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave + constexpr index_t M1 = get_warp_size() / (M2 * N0); + constexpr index_t M0 = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 1>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<1, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution() + { + using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile()); + return c_block_tensor_type::get_tile_distribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; + + return BlockGemmASmemBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV1CustomPolicy; + return BlockGemmARegBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmASmemBRegCRegV1CustomPolicy; + + return BlockGemmASmemBRegCRegV1{}; + } + + // template + // CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() + // { + // using BlockGemmProblem = + // BlockGemmPipelineProblem>; + // constexpr auto warp_gemm = []() { + // if constexpr(std::is_same_v && + // std::is_same_v && + // std::is_same_v) + // { + // return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{}; + // } + // else if constexpr(std::is_same_v && + // std::is_same_v && + // std::is_same_v) + // { + // return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{}; + // } + // }(); + + // using BlockGemmPolicy = + // BlockGemmASmemBSmemCRegV1CustomPolicy; + + // return BlockGemmASmemBSmemCRegV1{}; + // } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV1CustomPolicy; + return BlockGemmARegBSmemCRegV1{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() + { + using BlockGemmProblem = + BlockGemmPipelineProblem>; + + using WarpGemm = + WarpGemmMfmaDispatcher{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), + true>; + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; + return BlockGemmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a54a9fcb32f6bb7afd6a10293c9da98e91756e0a --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockFmhaBwdPipelineEnum +{ + KSKTSVR = 0, + QSKSVROGradS, + KSVR, +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7b787e9f36dcd38a914688703cf80c279647238f --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockFmhaBwdPipelineProblem +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using GemmDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using QGradDataType = remove_cvref_t; + using KGradDataType = remove_cvref_t; + using VGradDataType = remove_cvref_t; + using BiasGradDataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; + static constexpr bool kHasDropout = Traits::kHasDropout; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + +template +struct BlockFmhaBwdOGradDotOPipelineProblem +{ + using ODataType = remove_cvref_t; + using OGradDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using Traits = remove_cvref_t; + + static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0, + "kBlockSize should be divisible by get_warp_size()"); + + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kVHeaddim = kVHeaddim_; + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7efdb798cb21efd9981e518aa2a44a80a74f9fc1 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { +namespace detail { +template +struct log2; + +template <> +struct log2<16> : std::integral_constant +{ +}; + +template <> +struct log2<32> : std::integral_constant +{ +}; + +template <> +struct log2<64> : std::integral_constant +{ +}; + +template <> +struct log2<128> : std::integral_constant +{ +}; +} // namespace detail + +template +struct BlockFmhaFwdSplitKVCombinePipeline +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using LSEDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kHeadDimV = Problem::kHeadDimV; + static constexpr index_t kM0 = Problem::kM0; + static constexpr index_t kN1 = Problem::kN1; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr index_t kMaxSplits = Problem::kMaxSplits; + + static constexpr index_t kAlignmentLSE = + kPadSeqLenQ ? 1 : Policy::template GetAlignmentLSE(); + static constexpr index_t kAlignmentLSEacc = kAlignmentLSE; + + static constexpr index_t kAlignmentOacc = + kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kHeadDimV <= 32) + { + constexpr std::array occupancy{3, 3, 3, 1}; + return occupancy[detail::log2::value - 4]; + } + else if constexpr(kHeadDimV <= 128) + { + constexpr std::array occupancy{3, 3, 2, 1}; + return occupancy[detail::log2::value - 4]; + } + else if constexpr(kHeadDimV <= 256) + { + constexpr std::array occupancy{2, 2, 2, 1}; + return occupancy[detail::log2::value - 4]; + } + } + }(); + + static constexpr const char* name = "unused"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, + const OaccDramBlockWindowTmp& o_acc_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, + const LSEElementFunction& lse_element_func, + const OaccElementFunction& o_acc_element_func, + index_t num_splits, + index_t max_seqlen_q, + void* smem_ptr) const + { + // lse_acc tile in LDS + LSEDataType* lse_acc_lds_ptr = + static_cast(static_cast(static_cast(smem_ptr))); + auto lse_acc_lds = [=, lds_desc = Policy::template MakeLSEaccLdsBlockDescriptor()]( + index_t row, index_t col) -> LSEDataType& { + return lse_acc_lds_ptr[lds_desc.calculate_offset(make_tuple(row, col))]; + }; + + auto lse_acc_lds_write_window = [&]() { + auto view = make_tensor_view( + lse_acc_lds_ptr, Policy::template MakeLSEaccLdsStoreBlockDescriptor()); + return make_tile_window(view, make_tuple(number{}, number{}), {0, 0}); + }(); + + auto lse_acc_dram_window = + make_tile_window(lse_acc_dram_block_window_tmp.get_bottom_tensor_view(), + lse_acc_dram_block_window_tmp.get_window_lengths(), + lse_acc_dram_block_window_tmp.get_window_origin(), + Policy::template MakeLSEaccDramTileDistribution()); + + // copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]). + auto lse_acc_tile = load_tile(lse_acc_dram_window); + store_tile(lse_acc_lds_write_window, lse_acc_tile); + block_sync_lds(); + + auto lse_accum = make_static_distributed_tensor( + Policy::template MakeLSEaccRegTileDistribution()); + + // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)]) + // this will extend the distributed tensor width so that each thread in wave have data to + // reduce. + { + constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + lse_accum.get_tile_distribution(), i_j_idx); + + const auto col = x_indices.at(number<1>{}); + if(col < num_splits) + { + const auto row = x_indices.at(number<0>{}); + + lse_accum(i_j_idx) = lse_acc_lds(row, col); + } + else + { + lse_accum(i_j_idx) = -numeric::infinity(); + } + }); + }); + } + + // compute the logsumexp of the LSE along the split dimension. + const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + auto lse_max = block_tile_reduce( + lse_accum, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(lse_max, f_max, bool_constant{}); + + static const auto get_validated_m = [](LSEDataType raw_m) { + return raw_m == -numeric::infinity() ? type_convert(0.f) + : raw_m; + }; + + decltype(lse_accum) lse_exp; + { + constexpr auto spans = decltype(lse_exp)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + lse_exp(i_j_idx) = + ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx))); + }); + }); + } + + auto lse_sum = block_tile_reduce( + lse_exp, sequence<1>{}, f_sum, type_convert(0)); + block_tile_reduce_sync(lse_sum, f_sum, bool_constant{}); + + decltype(lse_max) lse_logsum; + { + constexpr auto spans = decltype(lse_logsum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx)) + { + lse_logsum(i_idx) = numeric::infinity(); + } + else + { + lse_logsum(i_idx) = + ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx)); + } + }); + } + + // store the lse scales in shared memory. + { + constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + const auto x_indices = get_x_indices_from_distributed_indices( + lse_accum.get_tile_distribution(), i_j_idx); + + const auto col = x_indices.at(number<1>{}); + if(col < num_splits) + { + const auto row = x_indices.at(number<0>{}); + + lse_acc_lds(row, col) = + ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx)); + } + }); + }); + } + block_sync_lds(); + + if constexpr(kStoreLSE) + { + constexpr auto spans = decltype(lse_logsum)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(lse_logsum(i_idx) == numeric::infinity()) + { + lse_logsum(i_idx) = -numeric::infinity(); + } + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); + } + + auto o_acc_dist = Policy::template MakeOaccDramTileDistribution(); + auto o_acc_dram_window = + make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(), + o_acc_dram_block_window_tmp.get_window_lengths(), + o_acc_dram_block_window_tmp.get_window_origin(), + o_acc_dist); + auto o_acc = make_static_distributed_tensor(o_acc_dist); + clear_tile(o_acc); + + const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0; + + for(index_t i_split = 0; i_split < num_splits; ++i_split) + { + auto o_tile = load_tile(o_acc_dram_window); + { + constexpr auto spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices( + o_acc.get_tile_distribution(), i_j_idx); + + const auto row = x_indices.at(number<0>{}); + + const LSEDataType lse_scale = lse_acc_lds(row, i_split); + o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx); + }); + }); + } + + move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0}); + } + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window, + const OaccDramBlockWindow& o_acc_dram_block_window, + LSEDramBlockWindow& lse_dram_block_window, + index_t num_splits, + index_t max_seqlen_q, + void* smem_ptr) const + { + return operator()(lse_acc_dram_block_window, + o_acc_dram_block_window, + lse_dram_block_window, + identity{}, + identity{}, + num_splits, + max_seqlen_q, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2eb092f055217b79435ee1d8d246ca4086635a90 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE() + { + using LSEDataType = remove_cvref_t; + return 16 / sizeof(LSEDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc() + { + using OaccDataType = remove_cvref_t; + return 16 / sizeof(OaccDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using ODataType = remove_cvref_t; + return 16 / sizeof(ODataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return sizeof(typename Problem::LSEDataType) * + MakeLSEaccLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution() + { + using LSEDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t kMPerBlock = Problem::kMaxSplits; + + constexpr index_t NPerThread = 16 / sizeof(LSEDataType); + constexpr index_t NThreads = kNPerBlock / NPerThread; + + constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads; + constexpr index_t TotalWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp); + + static_assert(NThreads * NPerThread == kNPerBlock); + static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + // 3d + padding, [kMaxSplits, kM0] + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor() + { + using LSEDataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::kMaxSplits; + constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t NPack = 16 / sizeof(LSEDataType); + + constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * NPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor( + lse_acc_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lse_acc_lds_block_desc; + } + + // 3d + padding, [kM0, kMaxSplits] + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor() + { + using LSEDataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::kMaxSplits; + constexpr index_t kNPerBlock = Problem::kM0; + constexpr index_t NPack = 16 / sizeof(LSEDataType); + + constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * NPack>{}, number{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor( + lse_acc_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return lse_acc_t_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size()); + constexpr index_t kMPerBlock = Problem::kM0; + + constexpr index_t NThreads = get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / NThreads; + + constexpr index_t MThreads = kBlockSize / NThreads; + constexpr index_t MPerThread = kMPerBlock / MThreads; + + static_assert(NThreads * NPerThread == kNPerBlock); + static_assert(MThreads * MPerThread == kMPerBlock); + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution() + { + using OaccDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kM0; + constexpr index_t kNPerBlock = Problem::kN1; + + constexpr index_t N1 = 16 / sizeof(OaccDataType); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t M2 = get_warp_size() / N0; + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a6d74b38851ae7bc5728bca7a27323b8453eadf6 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -0,0 +1,666 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaFwdSplitKVPipelineQRKSVS +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = true; // always store LSE (acc) + static constexpr bool kHasDropout = false; // ignore this flag + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0BlockLength <= 32) + { + return 2; + } + else if constexpr(kK0BlockLength <= 64) + { + return 3; + } + else if constexpr(kK0BlockLength <= 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + const LSEaccElementFunction& lse_acc_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + auto q = load_tile(q_dram_window); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking || kHasUnevenSplits) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse_acc, -numeric::infinity()); + + store_tile(lse_acc_dram_window_tmp, + tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + auto q_tile = tile_elementwise_in(q_element_func, q); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + do + { + // STAGE 1, QK gemm + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(k_dram_window, {0, kK0}); + clear_tile(s_acc); // initialize C + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_window); + } + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + + /// TODO: only check in last iteration without increasing code size + if constexpr(kHasUnevenSplits) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + set_tile_if(s_acc, + -numeric::infinity(), + [&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) { + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return seqlen_k_end_ <= col; + }); + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + dropout.Run( + smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); + } + + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + if constexpr(kStoreLSE) + { + // store lse acc + auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } +#else + lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_acc_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + num_splits, + i_split, + mask, + position_encoding, + scale_s, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ae363a4978587a1458bcc4a817efc043ac174264 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp @@ -0,0 +1,770 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = true; // always store LSE (acc) + static constexpr bool kHasDropout = false; // ignore this flag + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / log2e_v; +#endif + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kK0BlockLength <= 32) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && + FmhaMask::IsMasking) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 64) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 2; + else + return 3; + } + else if constexpr(kK0BlockLength <= 128) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kK0BlockLength <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr_async"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& /*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile + const LSEaccElementFunction& lse_acc_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + auto k_lds_load = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor(i_buf)), + Policy::template MakeKLdsLoadBlockDescriptor(i_buf).get_lengths(), + {0, 0}); + }, + number{}); +#else + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); +#endif + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if masked and no work to do. + if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse_acc = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse_acc, -numeric::infinity()); + + store_tile(lse_acc_dram_window_tmp, + tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) + + // Note: here occ are all cleard, return it + return o_acc; + } + __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = make_tile_window( + bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer()); + (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 + // auto q_tile = q; // tile_elementwise_in(q_element_func, q); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kK0BlockLength / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[number{})>{}]); + +#else + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); +#endif + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto v_buf = load_tile(v_dram_window, bool_constant{}); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[number{})>{}]); + +#else + get_slice_tile( + k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); +#endif + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + + /// TODO: only check in last iteration without increasing code size + if constexpr(kHasUnevenSplits) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + set_tile_if(s_acc, + -numeric::infinity(), + [&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) { + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return seqlen_k_end_ <= col; + }); + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf + } + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window = + make_tile_window(k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < num_total_loop); + + // store lse acc + if constexpr(kStoreLSE) + { + auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); + sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse_acc(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + } +#else + lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile + index_t num_splits, + index_t i_split, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + BlockDropout& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_acc_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + num_splits, + i_split, + mask, + position_encoding, + scale_s, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6109fa5ab9566e6bd801692b7b06e03fe2bb919f --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..338319ab32161de5d26b10ca6fd73be1ac7b0edf --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index 5500174086946211817ea64544d9784994cc748c..cf70dff63fa3f62373cd3f176a3cf5c43fff9817 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -13,4 +13,23 @@ enum class BlockFmhaPipelineEnum QSKSVS, }; +template +struct BlockFmhaPipelineEnumToStr; + +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr"; +}; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_async"; +}; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qs"; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 9d27b2df687f6e85fc3615f779255456759ab377..23b75f16ac10a5c847919c9a02914043b1804981 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -13,6 +13,7 @@ template struct BlockFmhaPipelineProblem { - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using BlockFmhaShape = remove_cvref_t; - using FmhaMask = remove_cvref_t; - using Traits = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; @@ -45,10 +47,76 @@ struct BlockFmhaPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kHasBias = Traits::kHasBias; + static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; +template +struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem +{ + static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; +}; + +template +struct BlockFmhaSplitKVCombinePipelineProblem +{ + using LSEDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kBlockSize = 256; + static constexpr bool kIsGroupMode = kIsGroupMode_; + + static constexpr index_t kHeadDimV = HeadDimV_; + static constexpr index_t kM0 = kM0_; + static constexpr index_t kN1 = kN1_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr index_t kMaxSplits = Traits::kMaxSplits; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 9e239bb91648f16b1bb8f912b70581db8f56372d..a392f0124dba7978e858f4cabaf6ce306cd14d4f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -13,19 +15,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQRKSVS { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -46,8 +49,9 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -82,7 +86,7 @@ struct BlockFmhaPipelineQRKSVS } else if constexpr(kK0BlockLength <= 128) { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -96,6 +100,8 @@ struct BlockFmhaPipelineQRKSVS static constexpr const char* name = "qr"; + using DropoutType = std::conditional_t; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -105,6 +111,7 @@ struct BlockFmhaPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, @@ -113,7 +120,8 @@ struct BlockFmhaPipelineQRKSVS typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, - typename OAccElementFunction> + typename OAccElementFunction, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -123,14 +131,17 @@ struct BlockFmhaPipelineQRKSVS const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, FmhaMask mask, + PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + DropoutType& dropout) const { static_assert( std::is_same_v> && @@ -237,6 +248,9 @@ struct BlockFmhaPipelineQRKSVS {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), @@ -270,13 +284,13 @@ struct BlockFmhaPipelineQRKSVS k_block_tile = load_tile(k_dram_window); } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -322,7 +336,7 @@ struct BlockFmhaPipelineQRKSVS } // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); @@ -338,6 +352,25 @@ struct BlockFmhaPipelineQRKSVS s_acc, bias_tile); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); @@ -382,7 +415,8 @@ struct BlockFmhaPipelineQRKSVS static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? type_convert(0.f) @@ -403,7 +437,8 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -427,7 +462,8 @@ struct BlockFmhaPipelineQRKSVS constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } @@ -450,6 +486,12 @@ struct BlockFmhaPipelineQRKSVS }); }); + if constexpr(kHasDropout) + { + dropout.template Run( + smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); + } + block_sync_lds(); if constexpr(std::is_same_v) { @@ -519,7 +561,8 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); } @@ -563,16 +606,21 @@ struct BlockFmhaPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename RandValDramBlockWindowTmp, + typename LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, + PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + DropoutType& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -582,14 +630,17 @@ struct BlockFmhaPipelineQRKSVS identity{}, bias_dram_block_window_tmp, identity{}, + randval_dram_block_window_tmp, lse_dram_block_window_tmp, identity{}, identity{}, identity{}, identity{}, mask, + position_encoding, scale_s, - smem_ptr); + smem_ptr, + dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 0573b50d0473f95f83c4b1738f97afb2c69773ba..8251627e6c8b1eb20a421f64b18983c935aa4b8a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -1,11 +1,13 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -14,19 +16,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQRKSVSAsync { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -51,8 +54,9 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -77,23 +81,30 @@ struct BlockFmhaPipelineQRKSVSAsync return Problem::kBlockPerCu; else { + // minimize occupancy + if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) + { + return 1; + } + if constexpr(kK0BlockLength <= 32) { - if constexpr(kPadSeqLenK && kHasBias && FmhaMask::IsMasking) + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && + FmhaMask::IsMasking) return 1; else return 2; } else if constexpr(kK0BlockLength <= 64) { - if constexpr(kPadSeqLenK && kHasBias) + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 2; else return 3; } else if constexpr(kK0BlockLength <= 128) { - if constexpr(kPadSeqLenK && kHasBias) + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -107,6 +118,8 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr const char* name = "qr_async"; + using DropoutType = std::conditional_t; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -116,6 +129,7 @@ struct BlockFmhaPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, + typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, @@ -124,7 +138,8 @@ struct BlockFmhaPipelineQRKSVSAsync typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, - typename OAccElementFunction> + typename OAccElementFunction, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -134,14 +149,17 @@ struct BlockFmhaPipelineQRKSVSAsync const VElementFunction& v_element_func, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile const LSEElementFunction& lse_element_func, const SAccElementFunction& s_acc_element_func, const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, FmhaMask mask, + PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + DropoutType& dropout) const { static_assert( std::is_same_v> && @@ -208,6 +226,7 @@ struct BlockFmhaPipelineQRKSVSAsync q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQDramTileDistribution()); + q_dram_window.init_raw(); // TODO: we use async Copy for K, which is inline asm // a side effect is we have to use inline asm for q as well @@ -247,8 +266,8 @@ struct BlockFmhaPipelineQRKSVSAsync const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - // check early exit if masked and no work to do. - if constexpr(FmhaMask::IsMasking) + // check early exit + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) { if(num_total_loop <= 0) { @@ -281,6 +300,17 @@ struct BlockFmhaPipelineQRKSVSAsync k_dram_block_window.get_window_origin(), Policy::template MakeKDramTileDistribution()); // K DRAM tile window for // load + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = [&]() { + if constexpr(kPadSeqLenK && + (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout))) + return bool_constant{}; + else + return bool_constant{}; + }(); + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window( bias_dram_block_window_tmp.get_bottom_tensor_view(), @@ -288,6 +318,9 @@ struct BlockFmhaPipelineQRKSVSAsync {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), @@ -295,7 +328,7 @@ struct BlockFmhaPipelineQRKSVSAsync Policy::template MakeVDramTileDistribution()); // prefetch K tile - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); @@ -318,7 +351,9 @@ struct BlockFmhaPipelineQRKSVSAsync { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { async_load_tile_raw(k_lds_store(number{})>{}), - k_dram_window); + k_dram_window, + k_oob_ck, + k_pre_np); if constexpr(i_k0 < k0_loops - 1) move_tile_window(k_dram_window, {0, kK0}); @@ -367,7 +402,7 @@ struct BlockFmhaPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(1); // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); @@ -383,6 +418,25 @@ struct BlockFmhaPipelineQRKSVSAsync s_acc, bias_tile); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); @@ -463,8 +517,9 @@ struct BlockFmhaPipelineQRKSVSAsync static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? type_convert(0.f) @@ -485,7 +540,8 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -509,7 +565,8 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } @@ -532,8 +589,25 @@ struct BlockFmhaPipelineQRKSVSAsync }); }); - const auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.template Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pk_fp16_fp32( + tile_elementwise_in(p_compute_element_func, p_compute)); + else + return cast_tile( + tile_elementwise_in(p_compute_element_func, p_compute)); + }(); // STAGE 3, KV gemm if constexpr(k1_loops > 1) @@ -583,16 +657,13 @@ struct BlockFmhaPipelineQRKSVSAsync { // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window = - make_tile_window(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), - Policy::template MakeKDramTileDistribution()); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) __builtin_amdgcn_s_barrier(); - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); } // tail @@ -617,7 +688,8 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); } @@ -661,16 +733,21 @@ struct BlockFmhaPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename RandValDramBlockWindowTmp, + typename LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, + PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + DropoutType& dropout) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -680,14 +757,17 @@ struct BlockFmhaPipelineQRKSVSAsync identity{}, bias_dram_block_window_tmp, identity{}, + randval_dram_block_window_tmp, lse_dram_block_window_tmp, identity{}, identity{}, identity{}, identity{}, mask, + position_encoding, scale_s, - smem_ptr); + smem_ptr, + dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 0e59ee6fe0621713e25cecab20f42338d101f581..f4767de0e95769f9da11bfc68ea47cea11dfc4e2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -13,19 +14,20 @@ namespace ck_tile { template struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -46,8 +48,9 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -82,7 +85,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 } else if constexpr(kK0BlockLength <= 128) { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -105,18 +108,23 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename RandValDramBlockWindowTmp, + typename LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& /*randval_dram_block_window_tmp*/, // not supported + LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported FmhaMask mask, + PositionEncoding /*position_encoding*/, float scale_s, float descale_qk, float descale_sv, - void* smem_ptr) const + void* smem_ptr, + BlockDropout& /*dropout*/) const // not supported { static_assert( std::is_same_v> && @@ -249,13 +257,13 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 k_block_tile = load_tile(k_dram_window); } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -300,7 +308,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 } // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -356,7 +364,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? type_convert(0.f) @@ -377,7 +386,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -401,7 +410,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 677c05769cbbc02c2a3872b31483630132c009fa..bc9ca93d09d9a50a892029c2ef1ae24254d7991b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" namespace ck_tile { @@ -12,19 +13,20 @@ namespace ck_tile { template struct BlockFmhaPipelineQSKSVS { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using QDataType = remove_cvref_t; - using KDataType = remove_cvref_t; - using VDataType = remove_cvref_t; - using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; - using BiasDataType = remove_cvref_t; - using LSEDataType = remove_cvref_t; - using PDataType = remove_cvref_t; - using OaccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using FmhaMask = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; using VLayout = remove_cvref_t; @@ -45,7 +47,7 @@ struct BlockFmhaPipelineQSKSVS static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr index_t kBlockPerCu = []() { @@ -63,7 +65,7 @@ struct BlockFmhaPipelineQSKSVS } else if constexpr(kK0BlockLength <= 128) { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -99,7 +101,8 @@ struct BlockFmhaPipelineQSKSVS typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, - typename OAccElementFunction> + typename OAccElementFunction, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -115,6 +118,7 @@ struct BlockFmhaPipelineQSKSVS const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -265,13 +269,13 @@ struct BlockFmhaPipelineQSKSVS k_block_tile = load_tile(k_dram_window); } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -313,7 +317,7 @@ struct BlockFmhaPipelineQSKSVS } // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); @@ -329,6 +333,25 @@ struct BlockFmhaPipelineQSKSVS s_acc, bias_tile); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); @@ -373,7 +396,8 @@ struct BlockFmhaPipelineQSKSVS static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? type_convert(0.f) @@ -394,7 +418,8 @@ struct BlockFmhaPipelineQSKSVS sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -418,7 +443,8 @@ struct BlockFmhaPipelineQSKSVS constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } @@ -510,7 +536,8 @@ struct BlockFmhaPipelineQSKSVS sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); } @@ -554,7 +581,8 @@ struct BlockFmhaPipelineQSKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -562,6 +590,7 @@ struct BlockFmhaPipelineQSKSVS const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -579,6 +608,7 @@ struct BlockFmhaPipelineQSKSVS identity{}, identity{}, mask, + position_encoding, scale_s, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 4fda6f008fb8ff0e0a22f8fc69b845078164f8e7..12af81bb9878205dc26e39513f998bbe72a5561c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -89,13 +89,13 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { - return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && @@ -212,13 +212,13 @@ struct BlockFmhaPipelineQXCustomPolicy std::is_same_v && std::is_same_v) { - return WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution{}; + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; } else if constexpr(std::is_same_v && std::is_same_v && @@ -691,7 +691,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { // TODO: assume Q is in register // TODO: assume K/V has same data type @@ -702,6 +702,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + if constexpr(AsyncCopyK) + { + return GetSmemSizeKV() + GetSmemSizeDropout(); + } + else + { + return ck_tile::max(GetSmemSizeKV(), GetSmemSizeDropout()); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout() + { + if constexpr(Problem::kHasDropout) + { + constexpr auto gemm_0 = QXPolicy::template GetQKBlockGemm(); + constexpr auto config = + decltype(gemm_0)::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t kMPerStep = MWarp * WG::kM; + constexpr index_t kNPerStep = WG::kN; + + return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t); + } + else + { + return 0; + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index d8a290b09cd372a36805797471d1531be3e31c54..64a61e94d1b4946c9b7018c1608114ed4c0dc0e5 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -43,4 +43,53 @@ struct TileFmhaShape ck_tile::tensor_layout::gemm::ColumnMajor>; }; +template +struct TileFmhaBwdShape +{ + using BlockTile = remove_cvref_t; + using Gemm0BlockWarps = remove_cvref_t; + using Gemm0WarpTile = remove_cvref_t; + using Gemm1BlockWarps = remove_cvref_t; + using Gemm1WarpTile = remove_cvref_t; + using Gemm2BlockWarps = remove_cvref_t; + using Gemm2WarpTile = remove_cvref_t; + using Gemm3BlockWarps = remove_cvref_t; + using Gemm3WarpTile = remove_cvref_t; + using Gemm4BlockWarps = remove_cvref_t; + using Gemm4WarpTile = remove_cvref_t; + + static constexpr index_t NumWarps = + reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{}); + + static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) && + NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{})); + + static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen + static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen + static constexpr index_t kK0 = + BlockTile::at(number<2>{}); // tile size along gemm0(Q@K^T) unroll + static constexpr index_t kK1 = + BlockTile::at(number<3>{}); // tile size along gemm1(P^T@dO) unroll + static constexpr index_t kK2 = + BlockTile::at(number<4>{}); // tile size along gemm2(dO@V^T) unroll + static constexpr index_t kK3 = + BlockTile::at(number<5>{}); // tile size along gemm3(dS^T@Q) unroll + static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll + static constexpr index_t kQKHeaddim = + BlockTile::at(number<7>{}); // Q & K headdim, used for pipeline that need load Q/Q^T or + // K/K^T at once + static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline + // that need load V at once +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 137f4ddd8190d63f4715b074aad721043bb9407f..a59431e39d3d2b444cd594176cfa341a6e300233 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" namespace ck_tile { @@ -11,8 +12,10 @@ template struct TileFmhaTraits @@ -21,10 +24,66 @@ struct TileFmhaTraits static constexpr bool kPadSeqLenK = kPadSeqLenK_; static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr bool kHasBias = kHasBias_; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; }; +template +struct TileFmhaFwdSplitKVTraits : TileFmhaTraits +{ + // determine if some split (length) is not divisible by tile size + static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; +}; + +template +struct TileFmhaFwdSplitKVCombineTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + + static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_); + static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0); + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + +template +struct TileFmhaBwdOGradDotOTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index c7ebcf9606460907db30419107acfc9202b6ac76..a89536e6ebe5b19d5cee5c8ef750aa33eecc2f28 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -3,20 +3,21 @@ #pragma once -#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp deleted file mode 100644 index 1053c751ad5b3de48bb41d1f5e7598187b6e7979..0000000000000000000000000000000000000000 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp +++ /dev/null @@ -1,25 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { -// Problem Description for BlockGemmARegBGmemCReg -template -struct BlockGemmARegBGmemCRegProblem -{ - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = kBlockSize_; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp index 7799bbe918e79771641812bf7044dee905075223..f097790ae6e1eb0c4a38739edaeca06be07cfc83 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -28,7 +28,7 @@ struct BlockGemmARegBGmemCRegV1 // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1< - BlockGemmARegBSmemCRegProblem, + BlockGemmProblem, BlockGemmARegBSmemCRegV1DefaultPolicy>; CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp index 4156398bd3f83c0647816fcc42bb992757da7ed8..0a17b053537f4509b359b3d2429e0d5f40a1733b 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp index aac9c4f5521ff9969386b1a299cf6b439e21417f..84883d6ed8449ea1fe2236a8f886e374f9d2b312 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" namespace ck_tile { @@ -35,13 +35,16 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); - constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; - constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; - constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + // constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + // constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + // constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; - static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, - "wrong!"); + // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + // KPerBlock == BlockGemmShape::kK, + // "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -181,23 +184,10 @@ struct BlockGemmARegBSmemCRegV1 }); } - // C = A * B - template - CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + CK_TILE_DEVICE constexpr auto MakeCBlockTile() const { - static_assert( - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; - constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; - constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; - - static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && - KPerBlock == BlockGemmShape::kK, - "wrong!"); + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); @@ -208,20 +198,7 @@ struct BlockGemmARegBSmemCRegV1 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; - - constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; - constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - - const index_t iNWarp = get_warp_id() % NWarp; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, @@ -231,108 +208,20 @@ struct BlockGemmARegBSmemCRegV1 sequence<1, 2>, sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - - constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } - // constrcut from A-block-tensor from A-Block-tensor-tmp - // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent - // distribution - auto a_block_tensor = - make_static_distributed_tensor(a_block_dstr); - - a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( - b_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, - make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); - -#if 0 // FIXME: using array will cause register spill - array, NIterPerWarp> b_warp_windows{ - {b_warp_window_tmp}}; - - for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_warp_windows; - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); -#endif - - // Construct C-Block-HostTensor - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - using AWarpDstr = typename WG::AWarpDstr; - using CWarpDstr = typename WG::CWarpDstr; - - using AWarpTensor = typename WG::AWarpTensor; - using CWarpTensor = typename WG::CWarpTensor; - - constexpr auto a_warp_y_lengths = - to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - - constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - // hot loop: - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); return c_block_tensor; } }; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp index 779113d96a6cedd56b3590ccacfb4de601282fa0..f998c67c952cfa46eebc4738e5d062e7a453ba9e 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index 8073989264b0f12356e0baf1f8f4879164dc40fb..9b10d435b67a5b9c3870c8493583b15540925fee 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp index 405d7f1258441c0cb7632b5961dac3ccb09bb0d5..4a82702c1ffd812900e5748cea786c1e78d3b878 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp index 8bcd04b7b0b6ea996bb146bdeb2ad0ae6944d5a9..20dcf2c270e622863866a8ed4bc71f1e5e5d0834 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp index c17385b8e56d5744f9234484618809c54feb3429..e90500c28c3c5601653a84e5f3dafdaa117d1672 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp new file mode 100644 index 0000000000000000000000000000000000000000..65ce1a9b8f107bf3e30aca978e7cff10fba12aaa --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block distributed tensor +// C is block distributed tensor +template +struct BlockGemmASmemBRegCRegV1 +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockWindowTmp& a_block_window_tmp, + const BBlockTensorTmp& b_block_tensor_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + // constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + // constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}]; + // constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + // KPerBlock == BlockGemmShape::kK, + // "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode); + + // constrcut from B-block-tensor from B-Block-tensor-tmp + // FIXME: need method to check b_block_tensor and b_block_tensor_tmp have equivalent + // distribution + auto b_block_tensor = + make_static_distributed_tensor(b_block_dstr); + + b_block_tensor.get_thread_buffer() = b_block_tensor_tmp.get_thread_buffer(); + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using BWarpTensor = typename WG::BWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A Block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + CK_TILE_DEVICE constexpr auto MakeCBlockTile() const + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp, + const BBlockTensorTmp& b_block_tensor_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_window_tmp, b_block_tensor_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5a17578f69e30e0636da09776fc407bd044f18f4 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockGemmASmemBRegCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd16f09c375f937ebefd723eade767b2adb86c2c --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBRegCRegV1 +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmASmemBRegCRegV1DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { +#if 0 + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + + constexpr index_t NumWarp = kBlockSize / get_warp_size(); + + // FIXME + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 4, 1); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); +#endif + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp deleted file mode 100644 index ed772891a43df820b41321b7136016eb6e8d6515..0000000000000000000000000000000000000000 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp +++ /dev/null @@ -1,26 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { - -// Problem Description for BlockGemmASmemBSmemCRegV1 -template -struct BlockGemmASmemBSmemCRegProblem -{ - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - - static constexpr index_t kBlockSize = kBlockSize_; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index 40da16d820f31c5c1e72c42ee13040be5c2f7a6b..ac4522170947aa0e3507039502351f902599a180 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp index 319711088f4a24b96b923d647be0c1bac6dcd1e3..2436457ec1de31e6a7bce32a3d9d1515c52ee739 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index fbb957727d6dff23810d008f2820ff3d03808717..f798d6e815f928ec9e3800215f285640ab88f955 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp similarity index 88% rename from include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp rename to include/ck_tile/ops/gemm/block/block_gemm_problem.hpp index 7a0390a8a200ccb465f4a0dc118ec1ac91ad1d93..d8f66c81caf25dcb8c86fd258abf43fbd4c482e1 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp @@ -7,13 +7,13 @@ namespace ck_tile { -// Problem Description for BlockGemmARegBSmemCReg +// Problem Description for BlockGemm template -struct BlockGemmARegBSmemCRegProblem +struct BlockGemmProblem { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index dfc63f04c69382b859b5e7a0558b46ca22523d45..5b4419b79f5c856b42b581756a6f911bb39bb0b0 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 = using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl>; +using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl< + WarpGemmAtrributeMfmaIterateK_SwizzleA>; + using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution>; @@ -38,7 +41,7 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>; -using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution = +using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; @@ -56,6 +59,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl>; +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl< + WarpGemmAtrributeMfmaIterateK_SwizzleA>; + using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl< WarpGemmAtrributeMfmaTransposedCDistribution>; @@ -72,7 +78,7 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>; -using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 71c59bbd17521373608466306ca1a9452752b666..fd5b004d362b9478356cf0ec76cbbe05f7780eed 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -468,4 +468,92 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB } }; +template +struct WarpGemmAtrributeMfmaIterateK_SwizzleA +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = + ext_vector_t::vector_size * kKIter>; + using BVecType = + ext_vector_t::vector_size * kKIter>; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK * kKIter; + static constexpr index_t SFactor = SFactor_; // group how many CM1 together + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void + operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); + }); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + constexpr auto I0 = number<0>{}; + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + auto c_vec = Impl{}( + reinterpret_cast(a_vec).template get_as()[I0], + reinterpret_cast(b_vec).template get_as()[I0]); + + static_for<1, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter]); + }); + + return c_vec; + } +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index cb250516f463ba535815417cefa85fc37b917362..dd164e72ea72b373cebacea7937e1fd4eced8448 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -36,8 +36,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; @@ -49,8 +48,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) return bit_cast( __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); #else @@ -89,8 +87,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; @@ -102,8 +99,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); #else @@ -143,7 +139,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { @@ -167,7 +163,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); #elif defined(__gfx908__) @@ -220,7 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { @@ -244,7 +240,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); #elif defined(__gfx908__) @@ -299,7 +295,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); @@ -333,7 +329,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3b66645ed4b882b335d8e908f8e9a06823048dc0 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp" +#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4be3e56874860a004e327d6177f9996c526150af --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -0,0 +1,291 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/welford/thread/thread_welford.hpp" +#include "ck_tile/ops/welford/warp/warp_welford.hpp" + +namespace ck_tile { + +// TODO: Extract some type to wrapper class +template +struct Layernorm2dFwd +{ + using Problem = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using BetaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using MeanDataType = ck_tile::remove_cvref_t; + using InvStdDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kHasBeta = !std::is_same_v; + static constexpr bool kSaveMean = !std::is_same_v; + static constexpr bool kSaveInvStd = !std::is_same_v; + + static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; + static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; + + static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; + + struct Kargs + { + const void* p_x; + const void* p_gamma; + const void* p_beta; + + void* p_y; + void* p_mean; + void* p_invStd; + + float epsilon; + + ck_tile::index_t M; + ck_tile::index_t N; + }; + + CK_TILE_HOST static constexpr Kargs MakeKargs(const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_mean, + void* p_invStd, + float epsilon, + ck_tile::index_t M, + ck_tile::index_t N) + { + return Kargs{p_x, p_gamma, p_beta, p_y, p_mean, p_invStd, epsilon, M, N}; + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t M) { return M / kMPerBlock; } + + CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; } + + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 1>>, + sequence<1, 2>, + sequence<2, 2>>{}); + } + + CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 1>>, + sequence<1>, + sequence<2>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr) + { + constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>(); + + using Lengths = decltype(nDstrSpan.impl_); + + ck_tile::index_t ret = 1; + + ck_tile::static_for<0, Lengths::size(), 1>{}( + [&](auto idx) { ret *= Lengths::template at(idx); }); + + return ret; + } + + template + CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor, + const ComputeDataType epsilon) + { + // TODO: Investigate fast inverse square root algorithm with epsilon + constexpr auto spans = DistributedTensor::get_distributed_spans(); + + DistributedTensor out_dstr_tensor; + + sweep_tile_span(spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + out_dstr_tensor(i_idx) = type_convert(1.0f) / + ck_tile::sqrt(in_dstr_tensor[i_idx] + epsilon); + }); + + return out_dstr_tensor; + } + + template + CK_TILE_DEVICE std::enable_if_t TwoPassLayernorm2dFwd(const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y, + MeanDataType* p_mean, + InvStdDataType* p_invStd, + const ComputeDataType epsilon, + ck_tile::index_t M, + ck_tile::index_t N) const + { + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{}); + + const auto gamma_n = make_naive_tensor_view( + p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{}); + + const auto beta_n = make_naive_tensor_view( + p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{}); + + const auto iM = get_block_id() * kMPerBlock; + + constexpr auto xDstr = MakeXBlockTileDistribution(); + + auto x_block_window = make_tile_window( + x_m_n, make_tuple(number{}, number{}), {iM, 0}, xDstr); + + index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock); + + // TODO: padding - handle max_count if N % kNPerBlock != 0 + constexpr auto NPerThread = GetNPerThread(xDstr); + ThreadWelford thread_welford{ + type_convert(NPerThread * N / kNPerBlock)}; + + using XTensorType = decltype(load_tile(x_block_window)); + auto mean_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + auto var_compute_block_tensor = + thread_welford.template MakeInitialMeanVarDistributedTensor(); + + clear_tile(mean_compute_block_tensor); + clear_tile(var_compute_block_tensor); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x_block_tensor = load_tile(x_block_window); + + thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); + move_tile_window(x_block_window, {0, kNPerBlock}); + } + + // TODO: support cross warp Welford + WarpMergeWelford{}( + mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_); + + auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon); + + if constexpr(kSaveMean) + { + const auto mean_m = make_naive_tensor_view_packed( + p_mean, make_tuple(M), number<32>{}); + + auto mean_block_window = + make_tile_window(mean_m, make_tuple(number{}), {iM}); + + store_tile(mean_block_window, cast_tile(mean_compute_block_tensor)); + } + if constexpr(kSaveInvStd) + { + const auto inv_std_m = make_naive_tensor_view_packed( + p_invStd, make_tuple(M), number<32>{}); + + auto inv_std_block_window = + make_tile_window(inv_std_m, make_tuple(number{}), {iM}); + + store_tile(inv_std_block_window, cast_tile(inv_std_compute_block_tensor)); + } + + // TODO: Extract normalize pipeline + const auto y_m_n = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{}); + + auto y_block_window = make_tile_window( + y_m_n, make_tuple(number{}, number{}), {iM, 0}); + + constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution(); + constexpr auto betaDstr = gammaDstr; + + auto gamma_block_window = + make_tile_window(gamma_n, make_tuple(number{}), {0}, gammaDstr); + + auto beta_block_window = make_tile_window( + beta_n, make_tuple(number{}, number{}), {0}, betaDstr); + + // reverse read x to reuse cache + ck_tile::index_t stride_to_right_most_window = N - kNPerBlock; + + move_tile_window(x_block_window, {0, -kNPerBlock}); + move_tile_window(gamma_block_window, {stride_to_right_most_window}); + move_tile_window(beta_block_window, {stride_to_right_most_window}); + move_tile_window(y_block_window, {0, stride_to_right_most_window}); + + // Normalization + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x_block_tensor = load_tile(x_block_window); + const auto gamma_block_tensor = load_tile(gamma_block_window); + const auto beta_block_tensor = load_tile(beta_block_window); + + constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans(); + + auto y_block_tensor = + make_static_distributed_tensor(x_block_tensor.get_tile_distribution()); + + sweep_tile_span(x_spans[I1], [&](auto idx1) { + constexpr auto j_idx = make_tuple(idx1); + const auto gamma = type_convert(gamma_block_tensor[j_idx]); + const auto beta = type_convert(beta_block_tensor[j_idx]); + + sweep_tile_span(x_spans[I0], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + const auto mean = mean_compute_block_tensor[i_idx]; + const auto inv_std = inv_std_compute_block_tensor[i_idx]; + + const auto x = type_convert(x_block_tensor[i_j_idx]); + auto y = (x - mean) * inv_std * gamma + beta; + + y_block_tensor(i_j_idx) = type_convert(y); + }); + }); + + store_tile(y_block_window, y_block_tensor); + + move_tile_window(x_block_window, {0, -kNPerBlock}); + move_tile_window(gamma_block_window, {-kNPerBlock}); + move_tile_window(beta_block_window, {-kNPerBlock}); + move_tile_window(y_block_window, {0, -kNPerBlock}); + } + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + TwoPassLayernorm2dFwd(static_cast(kargs.p_x), + static_cast(kargs.p_gamma), + static_cast(kargs.p_beta), + static_cast(kargs.p_y), + static_cast(kargs.p_mean), + static_cast(kargs.p_invStd), + static_cast(kargs.epsilon), + kargs.M, + kargs.N); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5206d36d7d2ebdd81fadbcc3c8d4d06a47387a93 --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct BlockLayernorm2dFwdProblem +{ + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using BetaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using MeanDataType = remove_cvref_t; + using InvStdDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp b/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1ff541d844d3624e890bab30e56477e6cc373b8e --- /dev/null +++ b/include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +template // Sequence<... +struct TileLayernorm2dShape +{ + static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); + static constexpr index_t kNPerThread = ThreadTile::at(number<1>{}); + + static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); + static constexpr index_t kNPerWarp = WarpTile::at(number<1>{}); + + static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; + static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread; + + static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); + static constexpr index_t kNPerBlock = BlockTile::at(number<1>{}); + + static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp; + static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp; + + // TODO - kNNumWarps can only be 1 if we don't support cross warp welford + static_assert(kNWarpPerBlock == 1); + + static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kNWarpPerBlock; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/welford.hpp b/include/ck_tile/ops/welford.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dffaad75017e6f4d39586263172eed1f986cda84 --- /dev/null +++ b/include/ck_tile/ops/welford.hpp @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/welford/thread/thread_welford.hpp" +#include "ck_tile/ops/welford/warp/warp_welford.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/welford/thread/thread_welford.hpp b/include/ck_tile/ops/welford/thread/thread_welford.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2ca9a23657437a490a350a225ad17cd099d99269 --- /dev/null +++ b/include/ck_tile/ops/welford/thread/thread_welford.hpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct ThreadWelford +{ + using XDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + + template + CK_TILE_DEVICE void Update(T& mean, T& var, T x) + { + if(ck_tile::isnan(x)) + { + mean = x; + var = x; + } + else + { + T delta = x - mean; + mean += delta / cur_count_; + T delta2 = x - mean; + var += delta * delta2; + } + } + + // [CAUSION] - max_count_ is to deal with the padding problem + // max_count_ is depend on caller, eg: naive and splitN welford will have different + // calculation of max_count_ + CK_TILE_DEVICE constexpr ThreadWelford(int max_count) : cur_count_(0), max_count_(max_count) {} + + template + CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor, + MeanDistributedTensor_& mean_tensor, + VarDistributedTensor_& var_tensor) + { + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + constexpr auto spans = XDistributedTensor_::get_distributed_spans(); + + sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) { + if(cur_count_ < max_count_) + { + ++cur_count_; + + sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) { + constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1); + constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0); + + auto x = ck_tile::type_convert(x_tensor[in_dstr_idx]); + + Update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x); + }); + } + }); + } + + template + CK_TILE_DEVICE static auto MakeInitialMeanVarDistributedTensor() + { + static_assert(std::is_same_v, "wrong!"); + + constexpr auto reduce_dims = sequence<1>{}; + + constexpr auto dstr = + make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( + XDistributedTensor_::get_tile_distribution() + .get_static_tile_distribution_encoding(), + reduce_dims)); + + auto tensor = make_static_distributed_tensor(dstr); + clear_tile(tensor); + + return tensor; + } + + template + CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor) + { + auto mean_tensor = MakeInitialMeanVarDistributedTensor(); + auto var_tensor = MakeInitialMeanVarDistributedTensor(); + + (*this)(x_tensor, mean_tensor, var_tensor); + + return ck_tile::make_tuple(mean_tensor, var_tensor); + } + + int cur_count_; + int max_count_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/welford/warp/warp_welford.hpp b/include/ck_tile/ops/welford/warp/warp_welford.hpp new file mode 100644 index 0000000000000000000000000000000000000000..687b61f430d0e68178eb494f94425c777fd23b9c --- /dev/null +++ b/include/ck_tile/ops/welford/warp/warp_welford.hpp @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct WarpMergeWelford +{ + using ComputeDataType = remove_cvref_t; + + template + CK_TILE_DEVICE static void + Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) + { + int count = count_a + count_b; + T count_ = type_convert(count); + T count_a_ = type_convert(count_a); + T count_b_ = type_convert(count_b); + T count_b_over_count = count == 0 ? type_convert(0) : count_b_ / count_; + + T delta = mean_b - mean_a; + mean_a += delta * count_b_over_count; + var_a += var_b + delta * delta * count_a_ * count_b_over_count; + count_a = count; + } + + template + CK_TILE_DEVICE void + operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count) + { + using Dstr = typename MeanDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + static_assert(std::is_same_v, + "wrong!"); + + constexpr index_t NDimP = Dstr::get_num_of_dimension_p(); + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_lane = NDimP - 1; + + const auto ps_idx = make_array(get_warp_id(), get_lane_id()); + const auto rs_idx = + mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); + + constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size(); + static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()); + + const int original_count = count; + + // loop over thread data + static_for<0, thread_buf_size, 1>{}([&](auto i) { + auto v_local_mean = mean_tensor.get_thread_buffer()[i]; + auto v_local_var = var_tensor.get_thread_buffer()[i]; + auto v_local_count = original_count; + + // cross-lane reduce for replication + // only reduce on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // reduction sweep forward + static_for<0, nstage, 1>{}([&](auto istage) { + constexpr index_t lid_delta = + lid_over_rid_derivative * (1 << (nstage - istage - 1)); + + // pull data from remote lane + const auto v_remote_mean = warp_shuffle_down(v_local_mean, lid_delta); + const auto v_remote_var = warp_shuffle_down(v_local_var, lid_delta); + const auto v_remote_count = warp_shuffle_down(v_local_count, lid_delta); + + // welford merge + Merge(v_local_mean, + v_local_var, + v_local_count, + v_remote_mean, + v_remote_var, + v_remote_count); + }); + } + }); + + // cross-lane broadcast for replication + // only broadcast on R dimension correspond to lane + // (lane id maps to this R dimension) + if constexpr(BroadcastLane) + { + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + const index_t r_id = rs_idx[idim_r]; + + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // broadcast sweep backward + static_for<0, nstage, 1>{}([&](auto istage) { + // do I hold reduced data? + const bool do_i_hold_reduced_data = r_id < (1 << istage); + + constexpr index_t lid_delta = lid_over_rid_derivative * (1 << istage); + + // pull data from remote lane + const auto v_remote_mean = warp_shuffle_up(v_local_mean, lid_delta); + const auto v_remote_var = warp_shuffle_up(v_local_var, lid_delta); + const auto v_remote_count = warp_shuffle_up(v_local_count, lid_delta); + + // decide whether to update local data with remote data + v_local_mean = do_i_hold_reduced_data ? v_local_mean : v_remote_mean; + v_local_var = do_i_hold_reduced_data ? v_local_var : v_remote_var; + v_local_count = do_i_hold_reduced_data ? v_local_count : v_remote_count; + }); + } + }); + } + + mean_tensor.get_thread_buffer()(i) = v_local_mean; + + if constexpr(GetActualVariance) + var_tensor.get_thread_buffer()(i) = v_local_var / v_local_count; + else + var_tensor.get_thread_buffer()(i) = v_local_var; + + count = v_local_count; + }); + } +}; + +} // namespace ck_tile diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..19fa6c209f003f1ee24d03acd6f1acf0b12fddcf --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp @@ -0,0 +1,337 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances); +#endif +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGemm_Streamk_V2; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + } +#endif + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..41303d2e9513ce4af45366e2f9c931bd8874b92f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp index dfd3216441aa683e47421bc765a2f7fc2e7dbef1..8b830d91d546a421d5ca092758bef5fab00a569c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp @@ -86,6 +86,7 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_bilinear_instances = std: //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | // generic instance + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7490ef22308c45ceb067e2059112fe41125f79bb --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Compute friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2388c4db0b695267aedc9d0fd9c9e5552f3608f1 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_grouped_conv_fwd_xdl_bf16_mem_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_f16_mem_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..96baf6bb00db0a78647544ed5ffdca5e694093a4 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd3x3 = ConvolutionForwardSpecialization::Filter3x3; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e3bec17514950d3bff849386b249d6f88a4c0b12 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F8 = ck::f8_t; +using BF8 = ck::bf8_t; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_xdl_outelementop_f8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#ifdef CK_ENABLE_FP8 + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, F8> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_outelementop_bf8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute Type| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_outelementop_f8_bf8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8) + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| Compute| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| TypeA| TypeB| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8) + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F8, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index dc56b8f4b2ffd26cc3dbdcbbd59ab2c2a75f690c..5a703e5814d15559dc40978643fa44840f3525e6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -352,6 +352,10 @@ struct DeviceOperationInstanceFactory>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( @@ -192,6 +216,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 8602a82ff87b3075e8bc5405c3e96523288c9116..0233d6d85c5b0afff95c86e58f17b28f828d2549 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -17,6 +17,10 @@ #endif #ifdef CK_USE_XDL #include "grouped_convolution_forward_xdl.inc" +#include "grouped_convolution_forward_xdl_merged_groups.inc" +#include "grouped_convolution_forward_comp_xdl.inc" +#include "grouped_convolution_forward_mem_inter_xdl.inc" +#include "grouped_convolution_forward_mem_intra_xdl.inc" #endif #ifdef CK_USE_WMMA #include "grouped_convolution_forward_wmma.inc" @@ -182,7 +186,7 @@ struct DeviceOperationInstanceFactory && is_same_v) { - add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); } #endif } @@ -196,6 +200,13 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -204,6 +215,13 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -214,6 +232,13 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + op_ptrs); } #endif } @@ -266,6 +291,13 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + op_ptrs); } #endif @@ -315,6 +347,13 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -325,6 +364,13 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_INT8 @@ -369,6 +415,17 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(op_ptrs); + } +#endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc new file mode 100644 index 0000000000000000000000000000000000000000..c93d6f4413c124816a9a8a69208bc9db00256516 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e7c24f884c68a607d7dd47fdc44521b9bd2291fa --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; + +#ifdef CK_ENABLE_FP8 +void add_device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvInvscale, + F8, + F8>>>& instances); +#endif + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = + DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + op_ptrs); + } +#endif + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..63dcdc6053695eebc4e86c90ce742fedf51b78d5 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +#ifdef CK_ENABLE_FP8 +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + F8, + F8>>>& instances); +#endif + +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector, + NDHWGK, + BF8, + BF8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + BF8>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + std::vector, + NDHWGK, + F8, + BF8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + F8, + BF8>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( + std::vector, + NDHWGK, + BF8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + BF8, + F8>>>& instances); +#endif + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = + DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + op_ptrs); + } +#endif + +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + op_ptrs); + } + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + op_ptrs); + } + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( + op_ptrs); + } +#endif + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ad86d066f75b20f35fbd572a08e00cb5aa575f95 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + +#ifdef CK_ENABLE_FP8 +void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScaleRelu, + F8, + F8>>>& instances); +#endif + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances( + op_ptrs); + } +#endif + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc new file mode 100644 index 0000000000000000000000000000000000000000..b3913443d3d37ecdcd91cd818b948343984b4918 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc new file mode 100644 index 0000000000000000000000000000000000000000..6874822e71323c8faed7561e4f7a83054800fb85 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index e627d428deca12b6dcd7818fc15122d39d1f0cbd..aaac9a2af29d616195cac78fa325333e0f610d3a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -75,7 +75,7 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( #ifdef CK_ENABLE_BF16 // grouped conv2d forward, GNHWC/GKYXC/GNHWK -void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp similarity index 55% rename from library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp rename to library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp index f7c03177666ce5549ffb1a61d8419b8f03cbfeae..3298ad940c9bf95fa25bf6efe73da7adf25ff6b8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp @@ -17,7 +17,150 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances( +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances( std::vector> op_ptrs; - // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances( + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances( + op_ptrs); + add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances( op_ptrs); } } @@ -132,7 +296,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; - // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) { @@ -199,7 +362,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; - // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) { @@ -266,7 +428,6 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; - // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) { diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index 4e075df43b0ffb0cd999ab332a75131e3d741e82..3336041354ae92f9218a080b84851234ba171c48 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -133,5 +133,40 @@ struct FillConstant } }; +template +struct TransformIntoStructuralSparsity +{ + // clang-format off + static constexpr T valid_sequences[] = { + 0, 0, 1, 1, + 0, 1, 0, 1, + 0, 1, 1, 0, + 1, 0, 0, 1, + 1, 0, 1, 0, + 1, 1, 0, 0, + }; + // clang-format on + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::for_each(first, last, [=, idx = 0](T& elem) mutable { + auto tmp_idx = idx; + idx += 1; + return elem *= valid_sequences[tmp_idx % (sizeof(valid_sequences) / sizeof(T))]; + }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + } // namespace utils } // namespace ck diff --git a/library/include/ck/library/utility/host_tensor.hpp b/library/include/ck/library/utility/host_tensor.hpp index ddbd16ad9a10f01252011da3339f46e512262163..493b992acac02022ea02daa604933110d326a890 100644 --- a/library/include/ck/library/utility/host_tensor.hpp +++ b/library/include/ck/library/utility/host_tensor.hpp @@ -43,7 +43,15 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) first = false; else os << delim; - os << static_cast(v); + + if constexpr(std::is_same_v || std::is_same_v) + { + os << ck::type_convert(v); + } + else + { + os << static_cast(v); + } } return os; } diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index c035e7e5641e698235a35bd1ad97c5791e764fd6..2081422e3a1071f60f4236aa939516eff2b4d5e3 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -36,6 +36,13 @@ function(add_instance_library INSTANCE_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(INST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(INST_TARGETS ${GPU_TARGETS}) + endif() + # Do not build DL instances if DL_KERNELS macro is not set foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") @@ -45,21 +52,40 @@ function(add_instance_library INSTANCE_NAME) endforeach() # Do not build XDL instances if gfx9 targets are not on the target list foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + if(NOT INST_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") message("removing xdl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() # Do not build WMMA instances if gfx11 targets are not on the target list foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") message("removing wmma instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() #only continue if there are some source files left on the list if(ARGN) - add_library(${INSTANCE_NAME} OBJECT ${ARGN}) + set(INST_OBJ) + foreach(source IN LISTS ARGN) + if(INSTANCES_ONLY) + set(INST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(INST_TARGETS ${GPU_TARGETS}) + endif() + if(source MATCHES "_xdl") + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(ARGN MATCHES "_wmma") + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set(offload_targets) + foreach(target IN LISTS INST_TARGETS) + string(APPEND offload_targets "--offload-arch=${target} ") + endforeach() + set_source_files_properties(${source} PROPERTIES COMPILE_FLAGS ${offload_targets}) + list(APPEND INST_OBJ ${source}) + endforeach() + add_library(${INSTANCE_NAME} OBJECT ${INST_OBJ}) target_compile_features(${INSTANCE_NAME} PUBLIC) set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) clang_tidy_check(${INSTANCE_NAME}) @@ -131,6 +157,14 @@ FOREACH(subdir_path ${dir_list}) if(NOT DEFINED DTYPES) set(add_inst 1) endif() + + if(INSTANCES_ONLY) + set(INST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(INST_TARGETS ${GPU_TARGETS}) + endif() + + if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8")) message("quantization instances will not be built!") set(add_inst 0) @@ -139,23 +173,23 @@ FOREACH(subdir_path ${dir_list}) message("Found only dl instances, but DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx9")) + if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9")) message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11")) + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT GPU_TARGETS MATCHES "gfx9")) + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT INST_TARGETS MATCHES "gfx9")) message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9")) message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) + if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.") set(add_inst 0) endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt index 7c22d86810df8bdf4f345179b440266b18f9cb8b..5af7322b1ab16ecb3402c83fe99f07fb676379e4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt @@ -2,9 +2,14 @@ set(GEMM_MULTI_ABD_INSTANCES) list(APPEND GEMM_MULTI_ABD_INSTANCES + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp - + device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..573dcc7d76cbfd44faf51f355d4b5ef689d4e4cc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple<>, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple<>, + EDataType, + AElementOp, + Multiply, + PassThrough>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + PassThrough, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + PassThrough, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6833ab20f83a7575f527fe4653b458da10c2fbc3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + Multiply, + Add>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + Add, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + Add, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 98546de040e022e7248ba46542ebe869808e5ed6..7cbf55e5f955b902c69fee0e39ed5da363c15118 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include @@ -52,112 +52,6 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( Interwave>{}); } -void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - Multiply, - Add>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Multiply, - Add, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Multiply, - Add, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( - std::vector, - ck::Tuple<>, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple<>, - EDataType, - AElementOp, - Multiply, - PassThrough>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - PassThrough, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - PassThrough, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( - std::vector, - ck::Tuple<>, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple<>, - EDataType, - AElementOp, - Multiply, - FastGelu>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - FastGelu, - GemmMNKPadding, - Interwave>{}); - - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - FastGelu, - GemmMNKPadding, - Interwave>{}); -} - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..044fc26721dcbbbca45c219e455d0dd3d97f97fb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple<>, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple<>, + EDataType, + AElementOp, + Multiply, + FastGelu>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + FastGelu, + GemmMNKPadding, + Interwave>{}); + + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + FastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12bcf192527fa42a37042d5001e1a59ecfa91667 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + Multiply, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + Multiply, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e4a04d48b3cf4dcf5d0af27de10c16e8fe2c3977 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyAdd>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAdd, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAdd, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 5c46730ea75dcab2256daa575710c25a6f331e3b..590e89284f7e71114146500999a07816e2abdd9e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include @@ -52,111 +52,6 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_i Interwave>{}); } -void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - PassThrough, - MultiplyAdd>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyAdd, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyAdd, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - PassThrough, - Multiply>>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - Multiply, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - Multiply, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - PassThrough, - MultiplyFastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyFastGelu, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyFastGelu, - GemmMNKPadding, - Interwave>{}); -} - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5741ee29ac1a6cd1812e4fdfc51ad64ff1581c66 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyFastGelu, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyFastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp index 452a9c9632a9a25976017ecfae213ef735d8e857..f2eb52b49a39206c0ed57fcfb948ead5961324eb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp @@ -43,7 +43,8 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // Disable due to test failure + // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2a930ab9ae020689c03a1ca39adeec6bb9d23fc6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt @@ -0,0 +1,26 @@ +# ONLY XDL_KERNELS +set(GEMM_UNIVERSAL_STREAMK_INSTANCES) + +list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp) + +add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6e8d5c798ba7e780dc6ab13fde12519fe7211edd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 4, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6adcb8f4f42239320474e3fa0144d00a041f4cfb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..631ae6872f54dde9cc6d175e561a0e5ce89bd162 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c49773a6588adbc6185edaa0b16ea90de1b1519 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..39d54fb885ffe701a3687ad480101744a772a9e7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ee50d63cb53ff55286ae449e8d0628bc044e59d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d31e0819a4795c4cdd0d3913513c1f248ac6f42a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fe19f35e531b4e65dcb7f98a3d9dc6f6dc2ec9dd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c1873b373cc90e05a1cb4fcabb05543ce576e24 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffd53f4069aa435a26c7a4226be8a68b66d3f576 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..094b8f92f8e09b60b7ee8ab3f3402d0ae2b5ece5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e00c1733e01cd40ebd01007a8acbebbf7e282734 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // AGPR Spill + // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..546f909b3ce1d04f81b26a7d103368ab0d60026e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d91de96be34ff072bb9661bfda98b86d3b7d5477 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c70678b4492345aa7c2f518fa8d26bbda1555ca1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5410a0cc251d3c964d078b9fe5b86eda5bb52925 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4ae7329f986902df2b028250cd7d8f6eceeac721 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4fc5458a960cd32b113d7d37ec3530fc39d60729 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7369f87a577625f7338440f35f468c7988fcf8fe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..45425a41a13a014ee377a63a4a79455b9f90778b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3b5ac0366ff0c288d991c7d913e96056dc2dd124 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53aa011a75e9df4aa4a73d1f261886fe20ffa1ec --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2a930ab9ae020689c03a1ca39adeec6bb9d23fc6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/CMakeLists.txt @@ -0,0 +1,26 @@ +# ONLY XDL_KERNELS +set(GEMM_UNIVERSAL_STREAMK_INSTANCES) + +list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp) + +add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6e8d5c798ba7e780dc6ab13fde12519fe7211edd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 4, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6adcb8f4f42239320474e3fa0144d00a041f4cfb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..631ae6872f54dde9cc6d175e561a0e5ce89bd162 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c49773a6588adbc6185edaa0b16ea90de1b1519 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..39d54fb885ffe701a3687ad480101744a772a9e7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ee50d63cb53ff55286ae449e8d0628bc044e59d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d31e0819a4795c4cdd0d3913513c1f248ac6f42a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fe19f35e531b4e65dcb7f98a3d9dc6f6dc2ec9dd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c1873b373cc90e05a1cb4fcabb05543ce576e24 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffd53f4069aa435a26c7a4226be8a68b66d3f576 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..094b8f92f8e09b60b7ee8ab3f3402d0ae2b5ece5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e00c1733e01cd40ebd01007a8acbebbf7e282734 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // AGPR Spill + // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..546f909b3ce1d04f81b26a7d103368ab0d60026e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d91de96be34ff072bb9661bfda98b86d3b7d5477 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c70678b4492345aa7c2f518fa8d26bbda1555ca1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5410a0cc251d3c964d078b9fe5b86eda5bb52925 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4ae7329f986902df2b028250cd7d8f6eceeac721 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4fc5458a960cd32b113d7d37ec3530fc39d60729 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7369f87a577625f7338440f35f468c7988fcf8fe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..45425a41a13a014ee377a63a4a79455b9f90778b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3b5ac0366ff0c288d991c7d913e96056dc2dd124 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53aa011a75e9df4aa4a73d1f261886fe20ffa1ec --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 340ddfb3f0f20aa43e34c7a3b1f571e1612d2922..f730534c8d48179eaa728780dbbcc70372b43ce7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -5,7 +5,10 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp - xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp) + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp + ) if(DL_KERNELS) list(APPEND GROUPED_CONV2D_BWD_WEIGHT diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15401f0e1bf5c462938eb504357db18de3fd4160 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..398c14b11cf09a18b5255b521cf3088e735cc7eb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 1d3c3747d398e0cb46b16a59827333040764dfdd..170625a6a090c4ea070d6c97b58569a31582b2a6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -9,6 +9,25 @@ add_instance_library(device_grouped_conv2d_fwd_instance xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + # merged groups + # NHWGC, GKYXC, NHWGK + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp + #mem + # NHWGC, GKYXC, NHWGK + xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp + # NHWGC, GKYXC, NHWGK + xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp + #comp + # NHWGC, GKYXC, NHWGK + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp #dl # GNHWC, GKYXC, GNHWK dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f0634735047bea75246ef5fbffdebb93e64552e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9b1c7ef65e280b0fea733d9941afc19844c7d2bb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93e07e08fbab81aa1e61e101ffe8ac54ee231b4e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index 08770b8617fa4a7645651ec51347dce671147202..2afbfdc38668b19aa3aac1784b7eda318b670c3b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3ae3fb5186432d9b0297722ab9838dfed5d6be45 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cb7e91293648646c22da0045688515e4802f5e55 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d787f4b048ce738c1d2a5dd060505a86f4b84d29 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5644289790391008293a91e89b2660c268de222a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b12dad5a3f1381f275df624b806ae6c1e1dfb78 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6fa4bc6e46fe4d83a349d542fbe68c49c0fac3d9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9fa56f48c77bcfcaad09e8ca5aad1ca44b8f5e54 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e226dae975bbdfa52c24143f0013867ba27e0c05 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp index d46be53ba46cea2d870ee1fe7fbae6face71ee7e..3f191ab6bc54b847b5575091fe97bea249be9e59 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp @@ -26,6 +26,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_ BF8, F8>>>& instances) { +#if CK_BUILD_DEPRECATED +#pragma message "These instances are getting deprecated" // 1. Default add_device_operation_instances( instances, @@ -44,6 +46,10 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_ Empty_Tuple, NDHWGC, ConvBwdDataFilter1x1Stride1Pad0>{}); +#else +#pragma message "These instances were deprecated" + std::ignore = instances; +#endif } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 8b89dcf7ec5184d865a6f0fc799b3292cec96009..8e939c15a9a964efebaef87a7b1b22cedc981e46 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,11 +1,14 @@ -# XDL_DL_WMMA_KERNELS + # XDL_DL_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp + ) if(DL_KERNELS) list(APPEND GROUPED_CONV3D_BWD_WEIGHT diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4d0f1e68cb9e0387727fcaa56243e2b592f89b7f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5cc062f2aaf5cf2cc3b35ce301bd2e640ca5b46 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp index 7f9493f602f5a9da84c7b43720f077c0a8f25f94..6e7f22b7e52aca04b6a01b425b0c6241bfbd20a7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp @@ -23,6 +23,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_ BF8, F8>>>& instances) { +#if CK_BUILD_DEPRECATED +#pragma message "These instances are getting deprecated" // 1. Default add_device_operation_instances( instances, @@ -41,6 +43,10 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_ GKZYXC, NDHWGK, ConvBwdWeightFilter1x1Stride1Pad0>{}); +#else +#pragma message "These instances were deprecated" + std::ignore = instances; +#endif } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 579bea00d8729823a47093187632f34d21ab968b..5be6672723beef7725dd85bcf89c79ba31b7dde3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -8,6 +8,23 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp + + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp + + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp + + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..efc464060349e6373e93e9d6f06c6954bde16e44 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3f3cd4b7d2153143295550a5719c7c1257d7f91b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..386c622610d90544368faf6af2b06130250a4cde --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp index 4651c67a7de13600d8dec6b18af19c326aa38003..7b5ddf0a869dbd7f78338fae2135c8088935b1ca 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp @@ -24,6 +24,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance PassThrough, F8>>>& instances) { +#if CK_BUILD_DEPRECATED +#pragma message "These instances are getting deprecated" add_device_operation_instances( instances, device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3, @@ -48,6 +50,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance Empty_Tuple, NDHWGK, ConvFwd1x1S1P0>{}); +#else +#pragma message "These instances were deprecated" + std::ignore = instances; +#endif } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6e7b4624b4ea15a6595ad3d17da5247417fdea19 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6fab8347d9b29f35ea5b1503ab3e4c5a78c53964 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f52f21454cfc73ac1d0b53f109c4dde96fff2e18 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5311888e96885f33e729bca591414af2403ef51 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9524378e040637f4b1496b5013cfd9bca1be5528 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..49332076fdcdf014a2937db85d0e62c97a0690aa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf1fcec9857c8de3f0fb483ba11ffd1fbb754361 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bea62892d3c61bc34f3d53f41733343ca6cae379 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..de447254135efc8682df86162b1844eb1eb91496 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..bbbe18bea6a7c264386071a06e0046e37efe773d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt @@ -0,0 +1,5 @@ +# ONLY XDL_KERNELS +set(GROUPED_CONV3D_FWD_CONVINVSCALE + xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp) + +add_instance_library(device_grouped_conv3d_fwd_convinvscale_instance ${GROUPED_CONV3D_FWD_CONVINVSCALE}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bba72be7c682c6e32dae46a1f68a135c04d35183 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvInvscale = ck::tensor_operation::element_wise::ConvInvscale; + +void add_device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvInvscale, + F8, + F8>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvInvscale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvInvscale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvInvscale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c7f4a3527e66524cdd03ad908d6f7085d425f418 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt @@ -0,0 +1,8 @@ +# ONLY XDL_KERNELS +set(GROUPED_CONV3D_FWD_CONVSCALE + xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp) + +add_instance_library(device_grouped_conv3d_fwd_convscale_instance ${GROUPED_CONV3D_FWD_CONVSCALE}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8e2c0eb1bd2765625ae3319ba3a83f7c0895f04d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( + std::vector, + NDHWGK, + BF8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + BF8, + F8>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvScale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..52cda2ea927a766d98bfc06c38fc42742d2321db --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector, + NDHWGK, + BF8, + BF8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + BF8>>>& instances) +{ + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvScale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d63f58853d1d9efc275c3e54f0ac717790b6ca9e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + std::vector, + NDHWGK, + F8, + BF8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + F8, + BF8>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvScale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cfc99f9dcbf35d4560159cb33e067bf59fbda43e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvScale = ck::tensor_operation::element_wise::ConvScale; + +void add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScale, + F8, + F8>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvScale>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvScale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c60df5a73392af0ccc447d3c6a97239079d503e3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt @@ -0,0 +1,5 @@ +# ONLY XDL_KERNELS +set(GROUPED_CONV3D_FWD_CONVSCALE_RELU + xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp) + +add_instance_library(device_grouped_conv3d_fwd_convscale_relu_instance ${GROUPED_CONV3D_FWD_CONVSCALE_RELU}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..472da0da78e9b657f6575620050c77318afa6915 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ConvScaleRelu = ck::tensor_operation::element_wise::ConvScaleRelu; + +void add_device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector, + NDHWGK, + F8, + F8, + ck::Tuple<>, + F8, + PassThrough, + PassThrough, + ConvScaleRelu, + F8, + F8>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwdDefault, + ConvScaleRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1P0, + ConvScaleRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_outelementop_f8_instances<3, + NDHWGC, + GKZYXC, + ck::Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + ConvScaleRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt index cbfcf8d221600b3af7a403cc6dbd8f71a4a1a926..0ba84c5cdc4bffd3362620ee80c9bbdbb3ce8565 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt @@ -5,8 +5,22 @@ set(GROUPED_GEMM_TILE_LOOP_INSTANCES) list(APPEND GROUPED_GEMM_TILE_LOOP_INSTANCES device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp - - device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp ) add_instance_library(device_grouped_gemm_tile_loop_instance ${GROUPED_GEMM_TILE_LOOP_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp index 505afbdff7b563cf50c1e07ac07b5c1819c68589..a41e6465b9f80a45001dec08cc1b3dff3b365334 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp @@ -38,16 +38,16 @@ using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_irregular_tile_inst //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp index 9653d3eef04056ec333c6c9872b070124f2a4de4..32c3829d15edc47e5b8cb97e43859163483ebf63 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp @@ -37,19 +37,19 @@ using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_irregular_tile_inst //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8>> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d943376a37ba878c688160f9e746d52441eef07a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Multiply = ck::tensor_operation::element_wise::Multiply; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; +using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = + std::tuple< + // clang-format off + //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6848774431e42c5a5afa0adf5e0bfd3e8a44472e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bb2ea76aa4049271bac58cfa707a39a13ce7e9fc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7439433f8a8ea3593ebc82c9bc6cc097585341e0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b3afed0fd7da9850f0e47480c420c9d64d39d2fb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmMNPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp index 0f62510a33f7e67b0bff6a4a7f7baca93f6d14bd..c98328e52d4c1fdfeef95044a94de0428caf73f4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -31,51 +31,63 @@ using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastG using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -template -using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances = std::tuple< -// clang-format off +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if 1 - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> -#endif -#if 0 - //comp - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, - - //latency - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>, - - //mem - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8> -#endif + //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on >; +template +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = + std::tuple< + // clang-format off + //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances( std::vector>>& instances) { + // comp add_device_operation_instances( instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances< - ck::Tuple, - ck::Tuple, - Multiply>{}); -} + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmDefault>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmMNKPadding>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmMNPadding>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + Multiply, + GemmKPadding>{}); + // mem + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmDefault, + Intrawave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNKPadding, + Intrawave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNPadding, + Intrawave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmKPadding, + Intrawave>{}); -void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances( - std::vector, - Row, - BF16, - I8, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAdd>>>& instances) -{ add_device_operation_instances( instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances< - ck::Tuple, - ck::Tuple, - MultiplyAdd>{}); + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmDefault, + Interwave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmKPadding, + Interwave>{}); } void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances( diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b6e5961cfeb90adf6a2eadc45e82363affbbc7b5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmDefault, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0662bd5fe989e184ae11b51f16aaa9e29dd5d982 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmKPadding, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cb6781b7b2b0e888b8ad6c27dd68ebefc9b3562c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNKPadding, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f2c07bf5d0c8046e31054fee3ce621a5b9cdcfc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v1_mnpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNPadding, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bd003f013988c74ba087bdcb69487b202c3bedac --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmDefault, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d003178447fb2cd75d791a33aba4480429359418 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5810b8a3d01079af92a883a51ad8d6e666d2b1ac --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8b4b37ed885cddaa31b13571afa8a8a7be23ad7a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_mnpadding_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + Multiply, + GemmMNPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13b0622e2ec4f88985c50c51d47d19dd1d0bfcdf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + MultiplyAdd>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + MultiplyAdd>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..33696e281e18a75fcea57a0c0e5130e6ccf133a0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + MultiplyAddFastGelu>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + MultiplyAddFastGelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f6e72ac2d637932a5e4cef69a3b6de5937f5e3ed --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + Row, + BF16, + I8, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + MultiplyFastGelu>{}); + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + MultiplyFastGelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/README.md b/profiler/README.md index a4daefba92187ac3b382c5531cb325f212a9aa59..10febcabdc4b9ab3a6925cbb2016815698c6addc 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -13,15 +13,6 @@ ./bin/ckProfiler gemm 1 1 1 1 0 5 3840 4096 4096 4096 4096 4096 ``` -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -```bash -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -.... -Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s -``` - ## Profile 2D forward convolution kernels ```bash #arg1: tensor operation (conv=Convolution) @@ -37,15 +28,6 @@ Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s ################ op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads ./bin/ckProfiler conv2d_fwd 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 ``` -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) - -```bash -in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} -wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} -out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} -.... -Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s -``` ## Profile contraction kernels ```bash @@ -71,16 +53,6 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s ./bin/ckProfiler contraction_bilinear 0 0 2 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 ``` -Result (MI100) -```bash -a_m_k: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} -b_k_n: dim 4, lengths {128, 128, 128, 128}, strides {128, 1, 2097152, 16384} -d_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} -e_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} -.... -Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s -``` - ## Profile batched gemm multiple D kernels ```bash #arg1: tensor operation (batched_gemm_multi_d=Batched GEMM multi D); @@ -99,14 +71,6 @@ Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s ./bin/ckProfiler batched_gemm_multi_d 0 1 0 0 0 1 4096 4096 4096 4096 4096 4096 16777216 16777216 16777216 16 ``` -Result (Radeon RX 6800 XT) -```bash -arg.a_grid_desc_k0_m0_m1_k1_{2048, 4096, 2} -arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} -arg.e_grid_desc_m_n_{ 4096, 4096} -.... -Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s -``` ## Profile grouped convolution backward data kernels ```bash # arg1: tensor operation (grouped_conv_bwd_data: Grouped Convolution Backward Data) @@ -134,20 +98,6 @@ Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s ``` -Result (MI100, FP16, GNHWC_GKYXC_GNHWK) - -```bash -out: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192} -wei: dim 5, lengths {32, 192, 192, 3, 3}, strides {331776, 1728, 1, 576, 192} -in: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192} -.... -Best configuration parameters: -name: DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<256, 128, 256, 32, 8, 2, Default, 32, 32, 2, 4, 8, 4, 1, 1> -avg_time: 0.768321 -tflops: 86.6679 -GB/s: 127.947 -``` - ## Profile grouped convolution backward weight kernels ```bash # arg1: tensor operation (grouped_conv_bwd_weight: Grouped Convolution Backward Weight) @@ -179,19 +129,6 @@ GB/s: 127.947 ``` -Result (MI100, FP16, GNHWC_GKYXC_GNHWK) - -```bash -input: dim 5, lengths {32, 512, 1024, 28, 28}, strides {411041792, 802816, 1, 28672, 1024} -weight: dim 5, lengths {32, 512, 1024, 3, 3}, strides {4718592, 9216, 1, 3072, 1024} -output: dim 5, lengths {32, 512, 512, 26, 26}, strides {177209344, 346112, 1, 13312, 512} -.... -Best configuration parameters: -name: DeviceGroupedConvBwdWeight_Xdl_CShuffle<256, 256, 128, 4, Default, 8, 4, 2, 8, 4, 8, 2, 1, 1, 8> -avg_time: 68.5216 -tflops: 95.337 -GB/s: 69.2301 -``` Note: This kernel use atomic add, this will cause output buffer to be accumulated multiple times, causing verification failure. To work around it, do not use CK's own timer and do verification at the same time. ## Profile image to column/column to image kernels @@ -224,17 +161,6 @@ Note: This kernel use atomic add, this will cause output buffer to be accumulate ``` -Result (MI210, FP32, NHWC) - -```bash -input: dim 5, lengths {1, 256, 512, 28, 28}, strides {102760448, 401408, 1, 14336, 512} -output: dim 2, lengths {173056, 4608}, strides {4608, 1} -.... -Best configuration parameters: -name: DeviceImageToColumn<128, 32, 64, 4> -avg_time: 3.12326 -GB/s: 2042.59 -``` Note: Column to image kernel adds to the output memory, this will cause output buffer to be accumulated multiple times, causing verification failure. To work around it, do not use CK's own timer and do verification at the same time. ## Profile Permute scale kernels @@ -254,12 +180,3 @@ Note: Column to image kernel adds to the output memory, this will cause output b ################ op datatype verify init log time dim0 dim1 dim2 in_stride0 in_stride1 in_stride2 out_stride0 out_stride1 out_stride2 ./bin/ckProfiler permute_scale 0 1 1 0 1 64 64 64 4096 64 1 1 64 4096 ``` - -Result (MI100, FP32) - -```bash -A: dim 3, lengths {64, 64, 64}, strides {4096, 64, 1} -B: dim 3, lengths {64, 64, 64}, strides {1, 64, 4096} -.... -Best perf = 0.0146878 ms, 142.782 GB/s, DeviceElementwiseNormalizationImpl<3, 2> -``` diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp index 362a5dccd11cd2408c77d593ed3de3e85e2d68b2..7fcadd7f7a79d6d37184915ef898665450b2aaaf 100644 --- a/profiler/include/profiler/profile_gemm_universal_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -191,7 +191,24 @@ bool profile_gemm_universal_impl(int do_verification, { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#if defined CK_ENABLE_FP8 + // set softer tolerances for fp8 + if constexpr(is_same_v || is_same_v || + is_same_v) + { + std::string msg = "Error: Incorrect results!"; + double rtol = 1e-1; + double atol = 1e-1; + pass = pass & ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); + } + else + { +#endif + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#if defined CK_ENABLE_FP8 + } +#endif if(do_log) { @@ -230,25 +247,6 @@ bool profile_gemm_universal_impl(int do_verification, << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl; -#if defined CK_ENABLE_FP8 - // set softer tolerances for fp8 - if constexpr(is_same_v || is_same_v || - is_same_v) - { - std::string msg = "Error: Incorrect results!"; - double rtol = 1e-1; - double atol = 1e-1; - pass = pass & ck::utils::check_err( - c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); - } - else - { -#endif - pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); -#if defined CK_ENABLE_FP8 - } -#endif - if(tflops > best_tflops) { best_op_name = op_name; diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..72194e8e6173e64a9111a8c35dab836be0da09db --- /dev/null +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -0,0 +1,332 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_universal_streamk_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int Streamk_sel, + int Grid_size, + int n_warmup, + int n_iter, + uint64_t rotating = 0) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes(); + int rotating_count = std::max( + 1, + std::min(n_iter, + static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + std::cout << "rotating count: " << rotating_count << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemm_Streamk_V2; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_grid_size = 0; + float best_streamk_sel = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + std::vector grid_size_list = {38, 76, 114, 152, 190, 228, 266, 304, 342, 380}; + std::vector streamk_sel_list = { + 0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP, + // 2:2-tile Stream-K + DP + + if(Grid_size == -1) + { + grid_size_list = {Grid_size}; + } + if(Streamk_sel != -1) + { + streamk_sel_list = {Streamk_sel}; + } + for(std::size_t j = 0; j < streamk_sel_list.size(); j++) + { + for(std::size_t i = 0; i < grid_size_list.size(); i++) + { + auto grid_size_curr = grid_size_list[i]; + index_t streamk_sel_curr = streamk_sel_list[j]; + printf("streamk_sel_curr=%0d\n", streamk_sel_curr); + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel_curr, + grid_size_curr, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, + time_kernel, + 0, + n_warmup, + n_iter, + rotating_count > 1, + rotating_count}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", Grid_size " + << grid_size_curr << ", streamk selection strategy" + << streamk_sel_curr << std::endl; + +#if defined CK_ENABLE_FP8 + // set softer tolerances for fp8 + if constexpr(is_same_v || is_same_v || + is_same_v) + { + std::string msg = "Error: Incorrect results!"; + double rtol = 1e-1; + double atol = 1e-1; + pass = pass & ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); + } + else + { +#endif + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#if defined CK_ENABLE_FP8 + } +#endif + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_grid_size = grid_size_curr; + best_streamk_sel = streamk_sel_curr; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC + << " Grid_size = " << best_grid_size + << " Stream-K selection strategy = " << best_streamk_sel << " : " << best_ave_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 5b981dda338a567a9104fc895c75bb40bd4f9411..356aec7a087b410c28cfe93a18742f95452a9971 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -188,6 +188,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, out_element_op, split_k); + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { // using atomic add, so need to reset input diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bd756eb825f7c0cfc3c21b0cc5cc5fd1bd13a8d7 --- /dev/null +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -0,0 +1,352 @@ +#pragma once + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace profiler { + +template +inline constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + auto pass = true; // return status + + using CShuffleDataType = float; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using InElementOp = PassThrough; + using WeiElementOp = PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor c(out_g_n_k_wos_desc); + Tensor host_output(out_g_n_k_wos_desc); + Tensor device_output(out_g_n_k_wos_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + + // random scale values + auto scale_in = type_convert( + type_convert(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX))); + auto scale_wei = type_convert( + type_convert(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX))); + auto scale_out = type_convert( + type_convert(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX))); + + // initialize out_element_op for each iteration + const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + + std::cout << "scale_in: " << scale_in << std::endl; + std::cout << "scale_wei: " << scale_wei << std::endl; + std::cout << "scale_out: " << scale_out << std::endl; + + // run reference op + if(do_verification) + { + + std::cout << "\nVerifying algorithm against reference convolution..." << std::endl; + std::cout << "\tUsing (rel_tol,abs_tol) = (" << std::setprecision(7) + << get_rtol() << ", " << get_atol() << ")" << std::endl; + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd{}; + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + c.SetZero(); + ref_invoker.Run(ref_argument); + + host_output.ForEach([&](auto&, auto idx) { out_element_op(host_output(idx), c(idx)); }); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init output to zero before profiling next kernel + out_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = pass & ck::utils::check_err(device_output, + host_output, + "Error: Device and Host results do not match!", + get_rtol(), + get_atol()); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + }; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + AComputeType, + BComputeType>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl; + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {}, + {}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + run_impl(op_ptr, argument_ptr); + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp index 67fba43d64a05f0926d1500ad044011d99eeafd7..09e03de99c21e70671f30b13aa1b3ab6ef0066ae 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -88,11 +88,12 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); -#if DEBUG_LOG - std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i - << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i - << "]:" << c_m_n_device_results[i].mDesc << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" + << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + } std::size_t num_thread = 1; switch(init_method) { diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 7f48ee0692149b95acae44d940d7babbcfa0df30..0b73e4fcd1fd1ad7810267b790fad5d84922e031 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -87,11 +87,12 @@ bool profile_grouped_gemm_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); -#if DEBUG_LOG - std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i - << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i - << "]:" << c_m_n_device_results[i].mDesc << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" + << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + } std::size_t num_thread = 1; switch(init_method) { diff --git a/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f665644162614aa235c343b354356d5a23f2b299 --- /dev/null +++ b/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp @@ -0,0 +1,347 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_gemm_multiply_tile_loop_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideDs, + const std::vector& StrideEs, + int n_warmup = 10, + int n_iter = 50) +{ + using CDataType = EDataType; + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::size_t group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideEs.size())) + { + throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + + std::vector> a_m_k; + std::vector> b_k_n; + std::vector> d_m_n; + std::vector> e_m_n_host_results; + std::vector> e_m_n_device_results; + + for(std::size_t i = 0; i < group_count; i++) + { + a_m_k.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); + b_k_n.push_back( + Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); + d_m_n.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideDs[i], DLayout{}))); + e_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{}))); + e_m_n_host_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{}))); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" + << i << "]:" << b_k_n[i].mDesc << ", e_m_n_device_results[" << i + << "]:" << e_m_n_device_results[i].mDesc << std::endl; + } + switch(init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5, 5}(a_m_k[i]); + ck::utils::FillUniformDistributionIntegerValue{-5, 5}(b_k_n[i]); + ck::utils::FillUniformDistributionIntegerValue{-5, 5}(d_m_n[i]); + break; + case 2: + ck::utils::FillUniformDistribution{.0, 1.}(a_m_k[i]); + ck::utils::FillUniformDistribution{-0.5, 0.5}(b_k_n[i]); + ck::utils::FillUniformDistribution{-0.5, 0.5}(d_m_n[i]); + break; + default: + ck::utils::FillConstant{1}(a_m_k[i]); + ck::utils::FillConstant{1}(b_k_n[i]); + ck::utils::FillConstant{1}(d_m_n[i]); + } + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using CDEElementOp = ck::tensor_operation::element_wise::Multiply; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceMemPtr = std::unique_ptr; + std::vector a_device_buf, b_device_buf, d_device_buf, e_device_buf; + + a_device_buf.reserve(group_count); + b_device_buf.reserve(group_count); + d_device_buf.reserve(group_count); + e_device_buf.reserve(group_count); + + std::vector p_a, p_b, p_d; + constexpr ck::index_t NumDTensor = 1; + auto p_ds = std::vector>{}; + std::vector p_e; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_ds.reserve(group_count); + p_e.reserve(group_count); + + using KernelArguments = + ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + + std::vector gemm_descs; + std::vector gemm_kargs; + + gemm_descs.reserve(group_count); + gemm_kargs.reserve(group_count); + + for(std::size_t i = 0; i < group_count; i++) + { + a_device_buf.emplace_back( + std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize())); + b_device_buf.emplace_back( + std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize())); + d_device_buf.emplace_back( + std::make_unique(sizeof(DDataType) * d_m_n[i].mDesc.GetElementSpaceSize())); + e_device_buf.emplace_back(std::make_unique( + sizeof(CDataType) * e_m_n_device_results[i].mDesc.GetElementSpaceSize())); + + a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); + b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); + d_device_buf[i]->ToDevice(d_m_n[i].mData.data()); + e_device_buf[i]->SetZero(); + + p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); + p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); + p_ds.push_back({d_device_buf[i]->GetDeviceBuffer()}); + p_e.push_back(e_device_buf[i]->GetDeviceBuffer()); + + gemm_descs.push_back( + {0, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {StrideDs[i]}}); + gemm_kargs.push_back({a_device_buf[i]->GetDeviceBuffer(), + b_device_buf[i]->GetDeviceBuffer(), + {d_device_buf[i]->GetDeviceBuffer()}, + e_device_buf[i]->GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {StrideDs[i]}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + if(do_verification) + { + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + Tensor c_m_n({Ms[i], Ns[i]}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument( + a_m_k[i], b_k_n[i], c_m_n, a_element_op, b_element_op, c_element_op); + ref_invoker.Run(ref_argument); + + for(int m = 0; m < Ms[i]; ++m) + { + for(int n = 0; n < Ns[i]; ++n) + { + cde_element_op(e_m_n_host_results[i](m, n), c_m_n(m, n), d_m_n[i](m, n)); + } + } + } + } + + // profile device GEMM instances + for(auto& gemm_ptr : op_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_e, + gemm_descs, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + cde_element_op); + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + std::string gemm_name = gemm_ptr->GetTypeString(); + + DeviceMem gemm_arg_dev_mem(gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + gemm_kargs.data(), + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer()); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + if(do_verification) + { + bool instance_pass = true; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + e_device_buf[i]->FromDevice(e_m_n_device_results[i].mData.data()); + instance_pass = instance_pass && ck::utils::check_err(e_m_n_device_results[i], + e_m_n_host_results[i]); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl; + LogRangeAsType( + std::cout << "e_device: ", e_m_n_device_results[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "e_host : ", e_m_n_host_results[i].mData, ",") + << std::endl; + } + } + + std::cout << "Instance: " << gemm_name << " verification " + << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; + } + + if(time_kernel) + { + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(EDataType) * Ms[i] * Ns[i] + // D matrix + sizeof(EDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + } + else + { + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; + } + } + + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp index 3d7fa470770059568904d8c46a51a9ee7a59ff11..74faf15be3e86cd655ec933025aaaffefa0ef5c3 100644 --- a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp @@ -82,11 +82,12 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification, Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); -#if DEBUG_LOG - std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i - << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i - << "]:" << c_m_n_device_results[i].mDesc << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" + << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + } switch(init_method) { case 0: break; diff --git a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp index 41dcabbfcf808cd6cf82e5a9be36b50aac7e1221..14df96d5057b137b8104107e21a846749715f670 100644 --- a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp @@ -88,11 +88,12 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); -#if DEBUG_LOG - std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i - << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i - << "]:" << c_m_n_device_results[i].mDesc << std::endl; -#endif // DEBUG_LOG + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" + << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + } std::size_t num_thread = 1; switch(init_method) { diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt old mode 100644 new mode 100755 index 1cfcbfff647164dba3df50c8269d0f1a61ea907d..198f49432f7f8c49fe3e10b217c586eed35bec97 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -43,6 +43,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) endif() list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) @@ -51,14 +52,16 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp) endif() -if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) endif() @@ -119,6 +122,7 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) @@ -131,9 +135,11 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance) endif() -if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") +if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) endif() diff --git a/profiler/src/profile_gemm_universal_streamk.cpp b/profiler/src/profile_gemm_universal_streamk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cd3f5787d691fdf2e4206c9b2c552d66b20ba114 --- /dev/null +++ b/profiler/src/profile_gemm_universal_streamk.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_universal_streamk_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F16_F16, // 4 + F16_F8_F16, // 5 + F16_F16_F16_F8, // 6 +}; + +#define OP_NAME "gemm_universal_streamk" +#define OP_DESC "Universal Streamk GEMM" + +int profile_gemm_universal_streamk(int argc, char* argv[]) +{ + if(argc != 16 && argc != 19) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, " + "comp f8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: Stream-k select strategy 0: all DP, 1: 1-tile SK, 2: 2-tile SK\n"); + printf("arg15: Grid-size, -1 for max persistent kernel occupancy\n"); + printf("optional:\n"); + printf("arg16: number of warm-up cycles (default 1)\n"); + printf("arg17: number of iterations (default 10)\n"); + printf("arg18: memory for rotating buffer (default 0, size in MB)\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const int Streamk_sel = std::stoi(argv[14]); + const int Grid_size = std::stoi(argv[15]); + + int n_warmup = 20; + int n_iter = 50; + uint64_t rotating = 0; + if(argc == 19) + { + n_warmup = std::stoi(argv[16]); + n_iter = std::stoi(argv[17]); + rotating = std::stoull(argv[18]) * 1024 * 1024; + } + + using F32 = float; + using F16 = ck::half_t; + // using BF16 = ck::bhalf_t; + // using F8 = ck::f8_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_universal_streamk_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC, + Streamk_sel, + Grid_size, + n_warmup, + n_iter, + rotating); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_universal_streamk); diff --git a/profiler/src/profile_grouped_conv_fwd_outelementop.cpp b/profiler/src/profile_grouped_conv_fwd_outelementop.cpp new file mode 100644 index 0000000000000000000000000000000000000000..196a2cf3f2188d9eab8402e6b3c794b575eb1c0b --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_outelementop.cpp @@ -0,0 +1,220 @@ +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK = 0, + NHWGC_GKYXC_NHWGK = 1 +}; + +enum struct OutElementOp +{ + ConvScale = 0, + ConvInvScale = 1 +}; + +enum struct ConvDataType +{ + F8_F8_F8 = 0, + BF8_BF8_F8 = 1, + F8_BF8_F8 = 2, + BF8_F8_F8 = 3 +}; + +#define OP_NAME "grouped_conv_fwd_outelementop" +#define OP_DESC "Grouped Convolution Forward+Elementwise Operation" + +static void print_helper_msg() +{ + // clang-format off + std::cout + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp8, Weight fp8, Output fp8\n" + << " 1: Input bf8, Weight bf8, Output fp8\n" + << " 2: Input fp8, Weight bf8, Output fp8\n" + << " 3: Input bf8, Weight fp8, Output fp8)\n" + << "arg3: element-wise operation (0: ConvScale\n" + << " 1: ConvInvScale)\n" + << "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_outelementop(int argc, char* argv[]) +{ + + // 9 total, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto op = static_cast(std::stoi(argv[3])); + const auto layout = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + 1 for argv[0] + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + using F8 = ck::f8_t; + using BF8 = ck::bf8_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using ConvScale = ck::tensor_operation::element_wise::ConvScale; + using ConvInvScale = ck::tensor_operation::element_wise::ConvInvscale; + + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto out_element_op, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using OutElementOp = decltype(out_element_op); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = ck::profiler::profile_grouped_conv_fwd_outelementop_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(op == OutElementOp::ConvScale) + { + if(data_type == ConvDataType::F8_F8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, ConvScale{}, F8{}, F8{}); + } + else if(data_type == ConvDataType::BF8_BF8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF8{}, + BF8{}, + F8{}, + ConvScale{}, + BF8{}, + BF8{}); + } + else if(data_type == ConvDataType::F8_BF8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, ConvScale{}, F8{}, BF8{}); + } + else if(data_type == ConvDataType::BF8_F8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, ConvScale{}, BF8{}, F8{}); + } + } + else if(op == OutElementOp::ConvInvScale) + { + if(data_type == ConvDataType::F8_F8_F8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, ConvInvScale{}, F8{}, F8{}); + } + else if(data_type == ConvDataType::BF8_BF8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF8{}, + BF8{}, + F8{}, + ConvInvScale{}, + BF8{}, + BF8{}); + } + else if(data_type == ConvDataType::F8_BF8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + F8{}, + BF8{}, + F8{}, + ConvInvScale{}, + F8{}, + BF8{}); + } + else if(data_type == ConvDataType::BF8_F8_F8) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF8{}, + F8{}, + F8{}, + ConvInvScale{}, + BF8{}, + F8{}); + } + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_outelementop); diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp index 25203d7b6c55d36860293b798ea13cd0739c9b75..fbf44d720f1eea5a21857408fd1893a4fe28139d 100644 --- a/profiler/src/profile_grouped_gemm.cpp +++ b/profiler/src/profile_grouped_gemm.cpp @@ -98,8 +98,8 @@ int profile_grouped_gemm(int argc, char* argv[]) int n_iter = 10; if(argc == 17) { - n_warmup = std::stoi(argv[16]); - n_iter = std::stoi(argv[17]); + n_warmup = std::stoi(argv[15]); + n_iter = std::stoi(argv[16]); } #ifdef CK_ENABLE_FP16 diff --git a/profiler/src/profile_grouped_gemm_multiply_tile_loop.cpp b/profiler/src/profile_grouped_gemm_multiply_tile_loop.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5cf0af5ecfb42e094195c0dbc8748669d1fce50f --- /dev/null +++ b/profiler/src/profile_grouped_gemm_multiply_tile_loop.cpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 +}; + +enum struct GemmDataType +{ + BF16_INT8_BF16_BF16, // 0 +}; + +#define OP_NAME "grouped_gemm_multiply_tile_loop" +#define OP_DESC "Grouped GEMM Multiply Multiple D Tile Loop" + +namespace { + +std::vector argToIntArray(char* input) +{ + std::vector out; + std::istringstream in(input); + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + return out; +} + +int profile_grouped_gemm_tile_loop(int argc, char* argv[]) +{ + if(argc < 14) + { + std::cout + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: bf16@int8)\n" + << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n" + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0=n0, 1=yes)\n" + << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "optional:\n" + << "arg14: number of warm-up cycles (default 1)\n" + << "arg15: number of iterations (default 10)\n" + << std::endl; + + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const auto Ms = argToIntArray(argv[8]); + const auto Ns = argToIntArray(argv[9]); + const auto Ks = argToIntArray(argv[10]); + + auto StrideAs = argToIntArray(argv[11]); + auto StrideBs = argToIntArray(argv[12]); + auto StrideCs = argToIntArray(argv[13]); + + const int DefaultStrideA = Ks[0]; + const int DefaultStrideB = Ns[0]; + const int DefaultStrideC = Ns[0]; + + for(size_t i = 0; i < Ms.size(); ++i) + { + StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i]; + StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i]; + StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i]; + } + + std::vector StrideDs(StrideCs); + + int n_warmup = 10; + int n_iter = 50; + if(argc == 16) + { + n_warmup = std::stoi(argv[14]); + n_iter = std::stoi(argv[15]); + } + + if(data_type == GemmDataType::BF16_INT8_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_multiply_tile_loop_impl< + ck::bhalf_t, + int8_t, + ck::bhalf_t, + ck::bhalf_t, + float, + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor, + ck::tensor_layout::gemm::RowMajor>(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideDs, + StrideCs, + n_warmup, + n_iter); + } + else + { + throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); + } + return 0; +} + +} // anonymous namespace + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_tile_loop); diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..8e7e8607baaed0eb94329a04018fe729eb8f7904 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "rocm-composable-kernel" +dynamic = ["version"] +description = "Composable Kernel, performance-critical kernels for machine learning workloads" +readme = "README.md" +requires-python = ">=3.8" +license = {file = "LICENSE"} +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] +dependencies = [] + +[project.urls] +"Homepage" = "https://github.com/rocm/composable_kernel" +"Bug Tracker" = "https://github.com/rocm/composable_kernel/issues" + +[tool.setuptools] +packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library"] + +[tool.setuptools.package-dir] +ck4inductor = "python/ck4inductor" +"ck4inductor.include" = "include" +"ck4inductor.library" = "library" + +[tool.setuptools.package-data] +"ck4inductor.include" = ["ck/**/*.hpp"] +"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"] + +[tool.setuptools.dynamic] +version = { attr = "setuptools_scm.get_version" } diff --git a/python/ck4inductor/__init__.py b/python/ck4inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/python/ck4inductor/universal_gemm/gen_instances.py b/python/ck4inductor/universal_gemm/gen_instances.py new file mode 100644 index 0000000000000000000000000000000000000000..8b6d6b73b2fca14709fa7cbefeaf3afe676b24f1 --- /dev/null +++ b/python/ck4inductor/universal_gemm/gen_instances.py @@ -0,0 +1,570 @@ +import logging +import os +import subprocess +from dataclasses import fields, replace +from functools import lru_cache, partial +from typing import List + +from ..util import library_path + +from .op import CKGemmOperation + +log = logging.getLogger(__name__) + + +def _ck_library_dir(): + gemm_instances_path = os.path.join( + library_path(), "src", "tensor_operation_instance", "gpu", "gemm_universal" + ) + if not os.path.exists(gemm_instances_path): + log.error("CK library path %s does not exist", gemm_instances_path) + return None + return gemm_instances_path + + +def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]: + """ + Parse the lines containing Universal Gemm template instances into `CKGemmOperation` instances + """ + + def maybe_int(s): + try: + return int(s) + except ValueError: + return s + + op_instances = [] + for line in str_instances: + s_template_args = line.split("DeviceGemm_Xdl_CShuffleV3")[-1].strip("<>, ") + template_args = [] + i_current = 0 + while i_current < len(s_template_args): + if s_template_args[i_current] == " ": + # skip whitespace + i_current += 1 + continue + elif s_template_args[i_current : i_current + 2] == "S<": + # parse template S + i_next = s_template_args.find(">", i_current) + template_args.append( + tuple(map(int, s_template_args[i_current + 2 : i_next].split(","))) + ) + i_current = i_next + 2 + else: + # all string attributes must be either type aliases or global constants in C++ + i_next = s_template_args.find(",", i_current) + template_args.append( + maybe_int( + s_template_args[i_current : i_next if i_next != -1 else None] + ) + ) + if i_next != -1: + i_current = i_next + 1 + if i_next == -1: + break + # pad with `None`s for the fields which are not defined in the instance + new_instance = CKGemmOperation( + *template_args, # type: ignore[arg-type] + *((None,) * (len(fields(CKGemmOperation)) - len(template_args))), + ) + # the last 2 template parameters are optional + # if they are absent, substitute them with default values from Universal Gemm C++ template declaration + if new_instance.a_compute_dtype is None: + new_instance.a_compute_dtype = new_instance.c_element_dtype + if new_instance.b_compute_dtype is None: + new_instance.b_compute_dtype = new_instance.c_element_dtype + + op_instances.append(new_instance) + return op_instances + + +def default_instances() -> List[CKGemmOperation]: + # fallback: known working op instance for problem size M=2240 K=256 N=2048 + # all string attributes must be either type aliases or global constants in C++ + + return [ + CKGemmOperation( + a_layout="Row", + b_layout="Row", + c_layout="Row", + a_element_dtype="F16", + b_element_dtype="F16", + c_element_dtype="F16", + a_compute_dtype="F16", + b_compute_dtype="F16", + acc_dtype="F32", + c_shuffle_dtype="F16", + a_elementwise_op="PassThrough", + b_elementwise_op="PassThrough", + c_elementwise_op="PassThrough", + gemm_specialization="GemmSpecialization::Default", + block_size=256, + m_per_block=224, + n_per_block=256, + k_per_block=64, + a_k1=8, + b_k1=2, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=7, + n_xdl_per_wave=8, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1), + a_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + a_block_transfer_src_access_order=(1, 0, 2), + a_block_transfer_src_vector_dim=2, + a_block_transfer_src_scalar_per_vector=8, + a_block_transfer_dst_scalar_per_vector_ak1=8, + a_block_lds_extra_m=0, # type: ignore[arg-type] + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1), + b_block_transfer_thread_cluster_arrange_order=(0, 2, 1), + b_block_transfer_src_access_order=(0, 2, 1), + b_block_transfer_src_vector_dim=1, + b_block_transfer_src_scalar_per_vector=8, + b_block_transfer_dst_scalar_per_vector_bk1=2, + b_block_lds_extra_n=0, # type: ignore[arg-type] + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ) + ] + + +@lru_cache(None) +def gen_ops_library() -> List[CKGemmOperation]: + """ + Parse the Universal Gemm instances defined in the composable kernel library folder. + """ + ck_library_dir = _ck_library_dir() + if not ck_library_dir: + return [] + + grep_result = subprocess.run( + [ + "grep", + "-inR", + "DeviceGemm_Xdl_CShuffleV3", + _ck_library_dir(), + ], + capture_output=True, + text=True, + ) + + op_instances = parse_instances(grep_result.stdout.strip().split("\n")) + + log.debug("ck instances from library: %d", len(op_instances)) + + schedulers = [ + "BlockGemmPipelineScheduler::Intrawave", + "BlockGemmPipelineScheduler::Interwave", + ] + gemm_specs = [ + "GemmSpecialization::Default", + "GemmSpecialization::MPadding", + "GemmSpecialization::NPadding", + "GemmSpecialization::KPadding", + "GemmSpecialization::MNPadding", + "GemmSpecialization::MKPadding", + "GemmSpecialization::NKPadding", + "GemmSpecialization::MNKPadding", + ] + + # substitute templated args by looping through their domains + substitute_instances = [] + for instance in op_instances: + sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched" + sub_spec = instance.gemm_specialization == "GemmSpec" + schedulers_range = ( + schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler] + ) + spec_range = gemm_specs if sub_spec else [instance.gemm_specialization] + for scheduler in schedulers_range: + for spec in spec_range: + substitute_instances.append( + replace( + instance, + block_gemm_pipeline_scheduler=scheduler, + gemm_specialization=spec, + ) + ) + + return substitute_instances + + +@lru_cache(None) +def gen_ops_preselected() -> List[CKGemmOperation]: + """ + Manually selected (through benchmarking) F16/F16/F16 Row/Col/Row instances + """ + ck_gemm_f16_rcr = partial( + CKGemmOperation, + a_layout="Row", + b_layout="Col", + c_layout="Row", + a_element_dtype="F16", + b_element_dtype="F16", + c_element_dtype="F16", + acc_dtype="F32", + c_shuffle_dtype="F16", + a_elementwise_op="PassThrough", + b_elementwise_op="PassThrough", + c_elementwise_op="PassThrough", + k_per_block=64, + a_k1=8, + b_k1=8, + a_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + a_block_transfer_src_access_order=(1, 0, 2), + a_block_transfer_src_vector_dim=2, + a_block_transfer_src_scalar_per_vector=8, + a_block_transfer_dst_scalar_per_vector_ak1=8, + a_block_lds_extra_m=0, + b_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + b_block_transfer_src_access_order=(1, 0, 2), + b_block_transfer_src_vector_dim=2, + b_block_transfer_src_scalar_per_vector=8, + b_block_transfer_dst_scalar_per_vector_bk1=8, + b_block_lds_extra_n=0, + a_compute_dtype="F16", + b_compute_dtype="F16", + ) + ck_gemm_f16_rcr_compute_friendly = partial( + ck_gemm_f16_rcr, + block_size=256, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1), + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ) + ck_gemm_f16_rcr_memory_friendly = partial( + ck_gemm_f16_rcr, + block_size=128, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1), + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Interwave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v2", + ) + ck_gemm_f16_rcr_latency_friendly = partial( + ck_gemm_f16_rcr, + gemm_specialization="GemmSpecialization::Default", + block_size=128, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1), + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v1", + ) + return [ + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=224, + n_per_block=256, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=7, + n_xdl_per_wave=8, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v4", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v5", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v4", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v5", + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=16, + n_per_block=32, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=16, + n_per_block=32, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=16, + n_per_block=64, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=64, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=32, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=64, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=2, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=2, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 64, + 1, + 2, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=64, + n_per_block=32, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=32, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=2, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_latency_friendly( + m_per_block=16, + n_per_block=32, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + ), + ck_gemm_f16_rcr_latency_friendly( + m_per_block=32, + n_per_block=16, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + ), + ] + + +if __name__ == "__main__": + print(gen_ops_library()) diff --git a/python/ck4inductor/universal_gemm/op.py b/python/ck4inductor/universal_gemm/op.py new file mode 100644 index 0000000000000000000000000000000000000000..ab541c5fb95e2a3927918f602ff80cf7cb0e41e0 --- /dev/null +++ b/python/ck4inductor/universal_gemm/op.py @@ -0,0 +1,95 @@ +from dataclasses import asdict, dataclass +from typing import Optional, Tuple + + +@dataclass +class CKGemmOperation: + """ + A python dataclass storing the template parameters of a CK Universal Gemm template instance + """ + + a_layout: str + b_layout: str + c_layout: str + + a_element_dtype: str + b_element_dtype: str + c_element_dtype: str + + acc_dtype: str + c_shuffle_dtype: str + + a_elementwise_op: str + b_elementwise_op: str + c_elementwise_op: str + + gemm_specialization: str + + block_size: int + + m_per_block: int + n_per_block: int + k_per_block: int + + a_k1: int + b_k1: int + + m_per_xdl: int + n_per_xdl: int + + m_xdl_per_wave: int + n_xdl_per_wave: int + + a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int] + a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] + a_block_transfer_src_access_order: Tuple[int, int, int] + a_block_transfer_src_vector_dim: int + a_block_transfer_src_scalar_per_vector: int + a_block_transfer_dst_scalar_per_vector_ak1: int + a_block_lds_extra_m: bool + + b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int] + b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] + b_block_transfer_src_access_order: Tuple[int, int, int] + + b_block_transfer_src_vector_dim: int + b_block_transfer_src_scalar_per_vector: int + b_block_transfer_dst_scalar_per_vector_bk1: int + b_block_lds_extra_n: bool + + c_shuffle_m_xdl_per_wave_per_shuffle: int + c_shuffle_n_xdl_per_wave_per_shuffle: int + + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: ( + Tuple[int, int, int, int] + ) + c_shuffle_block_transfer_scalar_per_vector_n_per_block: int + + block_gemm_pipeline_scheduler: str + block_gemm_pipeline_version: Optional[str] + + a_compute_dtype: Optional[str] + b_compute_dtype: Optional[str] + + def name(self): + # cpp alias for template instance + return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}" + + def key_name(self): + # TBD; must be unique per instance. Intended to use as dict key + return "_".join( + [ + "K" + + field_name.replace("_", "").lower() + + "V" + + ( + "x".join(map(str, iter(field_value))) + if isinstance(field_value, tuple) + else str(field_value).replace(":", "") + ) + for field_name, field_value in self.dict_items() + ] + ) + + def dict_items(self): + return asdict(self).items() diff --git a/python/ck4inductor/util.py b/python/ck4inductor/util.py new file mode 100644 index 0000000000000000000000000000000000000000..79d6be00f38b8397bf320b23c54df118c581816d --- /dev/null +++ b/python/ck4inductor/util.py @@ -0,0 +1,7 @@ +import functools +import os + + +@functools.lru_cache(None) +def library_path(): + return os.path.join(os.path.dirname(__file__), 'library') diff --git a/script/check_copyright_year.sh b/script/check_copyright_year.sh old mode 100755 new mode 100644 diff --git a/script/profile_grouped_conv_fwd_outelementop.sh b/script/profile_grouped_conv_fwd_outelementop.sh new file mode 100755 index 0000000000000000000000000000000000000000..ac444a25c2b5df71a7b96fb354424a6a6cbc1c05 --- /dev/null +++ b/script/profile_grouped_conv_fwd_outelementop.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" + +OP=$1 +DATATYPE=$2 +OUTELEMENTOP=$3 +LAYOUT=$4 +VERIFY=$5 +INIT=$6 +LOG=$7 +TIME=$8 + +N=$9 + +####### op datatype OUTELEMENTOP layout verify init log time Ndims G N K C Z Y X Di Hi Wi Sz Sy Sx Dz Dy Dx Left Pz LeftPy LeftPx RightPz RightPy RightPx +$DRIVER $OP $DATATYPE $OUTELEMENTOP $LAYOUT $VERIFY $INIT $LOG $TIME 3 32 $N 96 96 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1 +$DRIVER $OP $DATATYPE $OUTELEMENTOP $LAYOUT $VERIFY $INIT $LOG $TIME 3 32 $N 192 192 3 3 3 28 28 28 1 1 1 1 1 1 1 1 1 1 1 1 diff --git a/script/test_convnd_fwd.sh b/script/test_convnd_fwd.sh index 1bd7a6b5d71dfdc185b287a47fbda2dd1b7b4d86..8bd2c2fc33898167f6e62d87c55239c9e381e7ee 100644 --- a/script/test_convnd_fwd.sh +++ b/script/test_convnd_fwd.sh @@ -65,7 +65,7 @@ set -- "${POSITIONAL[@]}" # restore positional parameters # NUMACTL="numactl --cpunodebind=1 --membind=1" NUMACTL= # ENV_CONF= -GPU=mi100 +GPU=gfx908 PROF_ITER_COUNT=10000 LOG_DIR_PATH=../log/${LOG_DIR} set -x diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 33aa10df72baf6512884b6d390e313c3c046b908..fc1bcfdb272c61f1461add42d6019c50c285e763 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -40,6 +40,13 @@ function(add_test_executable TEST_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(TEST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(TEST_TARGETS ${GPU_TARGETS}) + endif() + foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl test ${source} ") @@ -47,20 +54,29 @@ function(add_test_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") message("removing xdl test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT TEST_TARGETS MATCHES "gfx11" AND NOT TEST_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() #only continue if there are some source files left on the list if(ARGN) + if(ARGN MATCHES "_xdl") + list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(ARGN MATCHES "_wmma") + list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + elseif(ARGN MATCHES "_smfmac") + list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a) + endif() + set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) + set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES ${TEST_TARGETS} ) target_link_libraries(${TEST_NAME} PRIVATE getopt::getopt) add_test(NAME ${TEST_NAME} COMMAND $) add_dependencies(tests ${TEST_NAME}) @@ -105,6 +121,13 @@ function(add_gtest_executable TEST_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(TEST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(TEST_TARGETS ${GPU_TARGETS}) + endif() + foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl test ${source} ") @@ -112,20 +135,29 @@ function(add_gtest_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") message("removing xdl test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT TEST_TARGETS MATCHES "gfx11" AND NOT TEST_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() #only continue if there are some source files left on the list if(ARGN) + if(ARGN MATCHES "_xdl") + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(ARGN MATCHES "_wmma") + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + elseif(ARGN MATCHES "_smfmac") + list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a) + endif() + set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) + set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES ${TEST_TARGETS} ) add_dependencies(tests ${TEST_NAME}) add_dependencies(check ${TEST_NAME}) @@ -181,3 +213,7 @@ add_subdirectory(wrapper) if(GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() +if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2 + add_subdirectory(smfmac_op) +endif() +add_subdirectory(position_embedding) diff --git a/test/contraction/test_contraction_xdl.cpp b/test/contraction/test_contraction_xdl.cpp index c84375b1db4fd8642eaa450e8c56a8b87208f685..2bfd5a6a66aed337de5043c4af2415a0114339ac 100644 --- a/test/contraction/test_contraction_xdl.cpp +++ b/test/contraction/test_contraction_xdl.cpp @@ -212,4 +212,10 @@ TYPED_TEST(TestContractionScaleMixedPrecision, scale) this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); + + // special cases + this->template Run<2>({{1, 1}, {16, 8}, {8, 16}}); + this->template Run<2>({{8, 16}, {16, 8}, {1, 1}}); + this->template Run<2>({{8, 16}, {1, 1}, {8, 16}}); + this->template Run<2>({{1, 1}, {1, 1}, {1, 1}}); } diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 3507989bae26bce6dad3e01d283306d51efbee65..8edb7152003b43f2fe588b1c8abc2bbf25479c16 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -2,11 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_x if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) endif() -add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) +add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) + target_link_libraries(test_grouped_convnd_bwd_data_interface_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance) endif() -add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) +add_gtest_executable(test_grouped_convnd_bwd_data_interface_wmma test_grouped_convnd_bwd_data_interface_wmma.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) + target_link_libraries(test_grouped_convnd_bwd_data_interface_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance) endif() diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp index c0429c6d093688916ce079fe48271a0abb6bd299..fbb6ffc6f569e9b8301972a5a1da32ae2618c9c8 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp @@ -52,6 +52,14 @@ class TestGroupedConvndBwdData : public ::testing::Test ck::utils::conv::ConvParam conv_param; + void SetUp() override + { + if(!ck::is_gfx11_supported()) + { + GTEST_SKIP(); + } + } + template bool Run() { diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index 54b514e7a149bcb2a5cc5d18eb55dbac6a9690c4..313b5ba4caf6dcc0803587fa8d923fb20349d2fd 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -5,13 +5,13 @@ if(GPU_TARGETS MATCHES "gfx9" OR DL_KERNELS) add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance) endif() -add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) +add_gtest_executable(test_grouped_convnd_bwd_weight_interface_xdl test_grouped_convnd_bwd_weight_interface_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) + target_link_libraries(test_grouped_convnd_bwd_weight_interface_xdl PRIVATE utility) endif() -add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) +add_gtest_executable(test_grouped_convnd_bwd_weight_interface_wmma test_grouped_convnd_bwd_weight_interface_wmma.cpp) if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) + target_link_libraries(test_grouped_convnd_bwd_weight_interface_wmma PRIVATE utility) endif() add_gtest_executable(test_grouped_conv_bwd_weight_xdl_bilinear test_grouped_conv_bwd_weight_xdl_bilinear.cpp) if(result EQUAL 0) diff --git a/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp b/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp index d733325a98308c7155406186cbab624263d9fa49..11748d4717c300baeda2c7706c123e26b1d598ce 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp @@ -264,5 +264,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 4, 4, {3, 3, 3}, {14, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->Run(); } diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index d100fb10778ca9811a051d7a23ea4769bb44138f..aee80cb2cbfe9afb44e2b65ec95ac54ba9cf8532 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -32,19 +32,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test std::vector conv_params; std::vector split_ks{1, 2}; - bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k) + bool skip_case(const ck::index_t split_k) { - // Odd K or C values are supported only by DL and WMMA - // kernels (only applies to fp16) - // DL and WMMA kernels currently support only `split_k=1` - if constexpr(std::is_same_v) - { - if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0)) - { - return true; - } - } - // 1d NWGC is only supported by DL kernel // DL kernel is only supported for split_k=1 if constexpr(std::is_same_v && std::is_same_v) @@ -55,14 +44,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } } - if(ck::is_navi3_supported()) + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { - // on navi3x only support for 3d is implemented + // on gfx11 only support for 3d is implemented if constexpr(NDimSpatial{} != 3) { return true; } - // on navi3x only support for i8 and fp16 is implemented + // on gfx11 only support for i8 and fp16 is implemented if constexpr(!((std::is_same_v && std::is_same_v && std::is_same_v) || @@ -80,7 +69,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } else { - // support for i8 is only implemented on navi3x + // support for i8 is only implemented on gfx11 if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { @@ -100,7 +89,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test { for(auto& param : conv_params) { - if(!skip_case(param, split_k)) + if(!skip_case(split_k)) { pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_implconv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}); this->Run(); } @@ -207,5 +198,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->Run(); } diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp index 1dcb8f866d161ab18e71d11db12da3f93088cf2b..2e2f5332ae74b98c4f21bbdc335ae54265e5366e 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_wmma.cpp @@ -52,6 +52,14 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ck::utils::conv::ConvParam conv_param; + void SetUp() override + { + if(!ck::is_gfx11_supported()) + { + GTEST_SKIP(); + } + } + template bool Run() { diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 4f245d63cdf539e85b19aa75b5bfba493bcde13c..f611e66243c22d7c8ab80040603cbd9e7811d8e8 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -1,6 +1,10 @@ -add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd_xdl_wmma.cpp) -if(result EQUAL 0) - target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) +if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") + add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) + if((GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) + target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) + else() + target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) + endif() endif() add_gtest_executable(test_grouped_convnd_fwd_multi_ab_interface test_grouped_convnd_fwd_multi_ab_interface.cpp) diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp similarity index 84% rename from test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp rename to test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index dde8313f944eb4f1f7296e92a1551ef52c6e4eef..1bfc1831353d9468579c47235f7e8d0958d76953 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -69,6 +69,8 @@ using KernelTypes3d = ::testing::Types std::tuple, std::tuple>; +using KernelTypes2dLargeCases = ::testing::Types>; + template class TestGroupedConvndFwd1d : public TestGroupedConvndFwd { @@ -84,9 +86,15 @@ class TestGroupedConvndFwd3d : public TestGroupedConvndFwd { }; +template +class TestGroupedConvndFwd2dLargeCases : public TestGroupedConvndFwd +{ +}; + TYPED_TEST_SUITE(TestGroupedConvndFwd1d, KernelTypes1d); TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d); TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); +TYPED_TEST_SUITE(TestGroupedConvndFwd2dLargeCases, KernelTypes2dLargeCases); TYPED_TEST(TestGroupedConvndFwd1d, Test1D) { @@ -96,6 +104,7 @@ TYPED_TEST(TestGroupedConvndFwd1d, Test1D) this->conv_params.push_back({1, 2, 32, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); this->conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}}); this->conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 96, 1, 1, 1, {3}, {512}, {1}, {1}, {1}, {1}}); this->template Run<1>(); } @@ -111,6 +120,8 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D) this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 96, 1, 1, 1, {3, 3}, {120, 160}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->template Run<2>(); } @@ -129,5 +140,18 @@ TYPED_TEST(TestGroupedConvndFwd3d, Test3D) {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {4, 30, 160}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->template Run<3>(); } + +TYPED_TEST(TestGroupedConvndFwd2dLargeCases, Test2DLargeCases) +{ + // Case larger than 2GB + this->conv_params.push_back( + {2, 1, 64, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}}); + // With supported NumGroupsToMerge > 1 + this->conv_params.push_back( + {2, 32, 64, 1, 1, {2, 2}, {672, 672}, {672, 672}, {1, 1}, {0, 0}, {0, 0}}); + this->template Run<2>(); +} diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp index c529a6a61bbc52afa7515bac86a43cf766b32fa4..346f04f66d5a60f71351e89744eb8d8a1f7de8f5 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_ab_interface.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -207,7 +207,7 @@ TEST_F(TestGroupedConvndFwdMultiAInterface, MultiA) std::array as{nullptr, nullptr}; const void* b = nullptr; - EXPECT_TRUE(this->template Run(as, b)); + EXPECT_TRUE(this->Run(as, b)); } TEST_F(TestGroupedConvndFwdMultiBInterface, MultiB) @@ -215,7 +215,7 @@ TEST_F(TestGroupedConvndFwdMultiBInterface, MultiB) const void* a = nullptr; std::array bs{nullptr, nullptr}; - EXPECT_TRUE(this->template Run(a, bs)); + EXPECT_TRUE(this->Run(a, bs)); } TEST_F(TestGroupedConvndFwdMultiABInterface, MultiAB) @@ -223,7 +223,7 @@ TEST_F(TestGroupedConvndFwdMultiABInterface, MultiAB) std::array as{nullptr, nullptr}; std::array bs{nullptr, nullptr}; - EXPECT_TRUE(this->template Run(as, bs)); + EXPECT_TRUE(this->Run(as, bs)); } TEST_F(TestGroupedConvndFwdInterface, SingleAB) @@ -231,5 +231,5 @@ TEST_F(TestGroupedConvndFwdInterface, SingleAB) const void* a = nullptr; const void* b = nullptr; - EXPECT_TRUE(this->template Run(a, b)); + EXPECT_TRUE(this->Run(a, b)); } diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index f47685cf91cf40dd26d452aa09336694d473d0ec..55cb209772b0cc140c739f76a1cca3aac0fd6e71 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -6,6 +6,12 @@ if(result EQUAL 0) add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) endif() +add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk) +endif() + add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) diff --git a/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp b/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..67ecbaea303c08bce68374a1f6469a014624682f --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +#include "gtest/gtest.h" +#include "test_grouped_gemm_util.hpp" + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using RRR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage>; +using RCR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage>; +using RRR_F16_F16_F16_LargeK = + ck::test::TestGroupedGemmTwoStage>; +using RCR_F16_F16_F16_LargeK = + ck::test::TestGroupedGemmTwoStage>; +using RRR_BF16_BF16_BF16 = + ck::test::TestGroupedGemmTwoStage>; +using RCR_BF16_BF16_BF16 = + ck::test::TestGroupedGemmTwoStage>; +using RRR_BF16_I8_BF16 = + ck::test::TestGroupedGemmTwoStage>; +using RCR_BF16_I8_BF16 = + ck::test::TestGroupedGemmTwoStage>; + +const std::vector KBATCH{1, 2, 3, 5, 8}; + +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN, + RRR_F16_F16_F16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK, + RCR_F16_F16_F16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16, + RRR_BF16_BF16_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16, + RCR_BF16_BF16_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16_INT8, + RRR_BF16_I8_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16_INT8, + RCR_BF16_I8_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_KN, + RRR_F16_F16_F16_LargeK, + testing::Values(32, 64)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_NK, + RCR_F16_F16_F16_LargeK, + testing::Values(32, 64)); + +#include "test_grouped_gemm_ut_cases.inc" +#include "test_grouped_gemm_two_stage_ut_cases.inc" diff --git a/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc new file mode 100644 index 0000000000000000000000000000000000000000..40d48f4ec0f1e96e5137362c7eee8477567667d7 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc @@ -0,0 +1,61 @@ +#pragma once + +TEST_P(RRR_BF16_BF16_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RCR_BF16_BF16_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_BF16_I8_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RCR_BF16_I8_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index 50f423ada399ef96ba443ac296e064ae1e083f8c..9e1395b9f8d22f9511efbb42c2f1f3b3e13fe025 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,6 +22,7 @@ #include "ck/utility/tuple.hpp" #include "ck/utility/number.hpp" #include "profiler/profile_grouped_gemm_impl.hpp" +#include "profiler/profile_grouped_gemm_two_stage_impl.hpp" namespace ck { namespace test { @@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam } }; +template +class TestGroupedGemmTwoStage : public testing::TestWithParam +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using ELayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using EDataType = std::tuple_element_t<5, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + + void SetUp() override {} + + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + template +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +#ifndef TEST_ALIBI_VERBOSE +#define TEST_ALIBI_VERBOSE 0 +#endif + +template +struct attention_score +{ + ck_tile::index_t rows, cols; + std::vector pixels; + + attention_score(ck_tile::index_t rows_, + ck_tile::index_t cols_, + DataType init_v_ = static_cast(0)) + : rows(rows_), cols(cols_), pixels(rows_ * cols_, init_v_) + { + } + + auto& operator()(ck_tile::index_t i_row, ck_tile::index_t i_col) + { + return pixels[i_row * cols + i_col]; + } + + void print() + { + for(auto i_row = 0; i_row < rows; i_row++) + { + for(auto i_col = 0; i_col < cols; i_col++) + { + std::cout << pixels[i_row * cols + i_col] << " "; + } + std::cout << std::endl; + } + } +}; + +template +void alibi_traverse_with_slope(attention_score& score, + DataType slope, + ck_tile::AlibiMode mode = ck_tile::AlibiMode::VERTICAL) +{ + using Alibi = ck_tile::Alibi; + auto alibi = Alibi{slope, score.rows, score.cols, mode}; + + for(ck_tile::index_t i_row = 0; i_row < score.rows; i_row++) + { + for(ck_tile::index_t i_col = 0; i_col < score.cols; i_col++) + { + alibi.update(score(i_row, i_col), i_row, i_col); + } + } +} + +std::string alibi_mode_to_str(ck_tile::AlibiMode mode) +{ + if(mode == ck_tile::AlibiMode::VERTICAL) + return std::string("alibi_verti"); + else if(mode == ck_tile::AlibiMode::FROM_TOP_LEFT) + return std::string("alibi_top-l"); + else if(mode == ck_tile::AlibiMode::FROM_BOTTOM_RIGHT) + return std::string("alibi_bot-r"); + return ""; +} + +template +bool test_alibi_traverse_with_slope(ck_tile::index_t rows, + ck_tile::index_t cols, + DataType slope, + ck_tile::AlibiMode mode, + const std::vector& expected) +{ + attention_score score{rows, cols}; + alibi_traverse_with_slope(score, slope, mode); + + bool is_match = std::equal(score.pixels.begin(), score.pixels.end(), expected.begin()); +#if TEST_ALIBI_VERBOSE + std::cout << "---------" << alibi_mode_to_str(mode) << ", " << rows << "x" << cols << "(" + << (RowMajor ? "row_major" : "col_major") << ")" + << (is_match ? ", valie:y" : ", valid:n") << std::endl; + score.print(); +#endif + return is_match; +} + +template +bool test_alibi_slope_generation(ck_tile::index_t nheads, const std::vector& expected) +{ + auto slopes = ck_tile::get_alibi_slopes(nheads); + + bool is_match = std::equal(slopes.begin(), + slopes.end(), + expected.begin(), + expected.end(), + [](const DataType& lhs, const DataType& rhs) { + constexpr float rtol = 1e-6; + auto error = std::abs(lhs - rhs); + return error < rtol * std::abs(rhs); + }); +#if TEST_ALIBI_VERBOSE + std::cout << "-------------------- slopes " << nheads << ", " << (is_match ? "y" : "n") + << std::endl; + for(ck_tile::index_t i = 0; i < nheads; i++) + { + std::cout << slopes[i] << " "; + } + std::cout << std::endl; +#endif + return is_match; +} + +int main() +{ + using dtype = int32_t; + dtype slope = static_cast(1); + + bool rtn = true; + + // clang-format off + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5, + -1, 0, -1, -2, -3, -4, + -2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0, + -4, -3, -2, -1, + -5, -4, -3, -2}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2, + -4, -3, -2, -1, 0, -1, + -5, -4, -3, -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5, + -1, -2, -3, -4, + 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5, + -1, 0, -1, -2, -3, -4, + -2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0, + -4, -3, -2, -1, + -5, -4, -3, -2}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2, + -4, -3, -2, -1, 0, -1, + -5, -4, -3, -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5, + -1, -2, -3, -4, + 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); + + rtn &= test_alibi_slope_generation(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625}); + rtn &= test_alibi_slope_generation(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692, + 0.12500000000000006, 0.08838834764831849, 0.06250000000000004, 0.044194173824159244, + 0.03125000000000002, 0.022097086912079626, 0.01562500000000001, 0.011048543456039816, + 0.007812500000000007, 0.005524271728019908, 0.003906250000000004}); + rtn &= test_alibi_slope_generation(1, {0.00390625}); + rtn &= test_alibi_slope_generation(5, {0.25, 0.0625, 0.015625, 0.00390625, 0.5}); + rtn &= test_alibi_slope_generation(6, {0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125}); + rtn &= test_alibi_slope_generation(7, {0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125, 0.03125}); + rtn &= test_alibi_slope_generation(9, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625, 0.7071067811865476}); + // clang-format on + return rtn ? 0 : -1; +} diff --git a/test/smfmac_op/CMakeLists.txt b/test/smfmac_op/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4ffc423f541213bd68150cf4084168c6240e9c10 --- /dev/null +++ b/test/smfmac_op/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_smfmac_op smfmac_op_xdl.cpp) +target_link_libraries(test_smfmac_op PRIVATE utility) diff --git a/test/smfmac_op/smfmac_op.cpp b/test/smfmac_op/smfmac_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..de4f9414af3400f9b75c6b002820a3752c98ac53 --- /dev/null +++ b/test/smfmac_op/smfmac_op.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "test/smfmac_op/smfmac_op_util.hpp" + +template +bool run_test() +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + bool pass = true; + + const auto matmul_default = ck::smfmac_op_util::matmul; + + const auto smfmac_kernel_container = std::make_tuple(matmul_default); + + ck::static_for<0, 1, 1>{}([&](auto i) { + pass &= + ck::smfmac_op_util::TestSmfmac{}>( + smfmac_kernel_container)), + Src1Type, + Src2Type, + DstType, + GPUAccType, + CPUAccType, + decltype(Row{}), + decltype(Row{}), + decltype(Row{}), + PassThrough, + PassThrough, + PassThrough, + AccVecSize, + M, + N, + K>{}(std::get{}>(smfmac_kernel_container)); + }); + + return pass; +} +int main(int, char*[]) +{ + bool pass = true; + // clang-format off + // | Src1Type| Src1VecSize| Src2Type| Src2VecSize| DstType| DstVecSize| GPUAccType| CPUAccType| M| N| K| + pass &= run_test< ck::half_t, 4, ck::half_t, 8, float, 4, float, float,16,16,32>(); + pass &= run_test(); + pass &= run_test< ck::half_t, 4, ck::half_t, 8, float, 16, float, float,32,32,16>(); + pass &= run_test(); + // clang-format on + + std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + return pass; +} diff --git a/test/smfmac_op/smfmac_op_util.hpp b/test/smfmac_op/smfmac_op_util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..44122c551d2e9b38274c2fc882baf4b7c66caf3e --- /dev/null +++ b/test/smfmac_op/smfmac_op_util.hpp @@ -0,0 +1,361 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/utility/amd_smfmac.hpp" +#include "ck/library/utility/fill.hpp" + +namespace ck { +namespace smfmac_op_util { + +template +__device__ void +builtin_smfmac_naive_selector(const src_vec1&, const src_vec2&, const int32_t&, acc_vec&) +{ +} + +template <> +__device__ void +builtin_smfmac_naive_selector>( + const half4_t& reg_a, + const half8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_16x16x32f16<16, 16>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void +builtin_smfmac_naive_selector>( + const bhalf4_t& reg_a, + const bhalf8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_16x16x32bf16<16, 16>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void builtin_smfmac_naive_selector< + half4_t, + half8_t, + StaticBufferTupleOfVector>( + const half4_t& reg_a, + const half8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_32x32x16f16<32, 32>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +template <> +__device__ void builtin_smfmac_naive_selector< + bhalf4_t, + bhalf8_t, + StaticBufferTupleOfVector>( + const bhalf4_t& reg_a, + const bhalf8_t& reg_b, + const int32_t& reg_idx, + StaticBufferTupleOfVector& reg_c) +{ + intrin_smfmac_f32_32x32x16bf16<32, 32>::Run( + reg_a, reg_b, reg_idx, reg_c.GetVectorTypeReference(Number<0>{})); +} + +// Smfmac instructions are using 4:2 structural sparsity, that means that in every contignuous +// subgroup of 4 elements, atleast 2 must be equal to zero and the position of non-zero elements is +// stored in idx register to allow selection of corresponding B matrix elements for multiplication. +// Currently smfmac instructions support only A matrix as sparse +template +__global__ void matmul(const src1_t* a, const src2_t* b, dst_t* c) +{ + __shared__ src1_t a_shared[M * K]; + __shared__ src2_t b_shared[K * N]; + const int lane = threadIdx.x; + // smfmac's A part is storing only non-zero elements in 2VGPRs + // smfmac's B part is storing all elements in 4VGPRs + using src1_vec = typename vector_type::type; + using src1_full_vec = typename vector_type::type; + using src2_vec = typename vector_type::type; + src1_vec a_frag = {}; + src2_vec b_frag = {}; + + src1_full_vec a_temp = {}; + src2_vec b_temp = {}; + // initialize c fragment to 0 + using acc_vec = StaticBufferTupleOfVector; + acc_vec c_thread_buf_; + + for(int i = 0; i < 8; ++i) + { + a_temp[i] = a[(lane % M) * K + (lane / M) * 8 + i]; // M K + } + + for(int i = 0; i < 8; ++i) + { + b_temp[i] = b[(8 * (lane / N) + i) * N + (lane % N)]; // K N + } + + __syncthreads(); + + for(int i = 0; i < 8; ++i) + { + a_shared[(lane % M) * K + (lane / M) * 8 + i] = a_temp[i]; + } + for(int i = 0; i < 8; ++i) + { + b_shared[(8 * (lane / N) + i) * N + (lane % N)] = b_temp[i]; + } + + __syncthreads(); + + // Idx must be a 32-bit register and it is storing 4 2-bit indexes of A's non zero elements. + // It starts with last two elements of every 4 elements subgroup set as non-zero + int32_t idx = 0b11101110; + // Bit masks are for zeroing 0-3rd position of idx + static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000}; + + src1_t curr_val; + int32_t a_pos = 0; + for(int j = 0; j < 2; ++j) + { + a_pos = j * 2; + for(int i = 0; i < 4; ++i) + { + curr_val = a_shared[(lane % M) * K + (lane / M) * 8 + 4 * j + i]; + if(curr_val != 0.0f) + { + idx &= ~bit_clear_masks[a_pos]; + idx |= (i % 4) << 2 * a_pos; + a_frag[a_pos] = curr_val; + a_pos++; + } + } + } + + for(int i = 0; i < 8; ++i) + { + b_frag[i] = b_shared[(8 * (lane / N) + i) * N + (lane % N)]; + } + + builtin_smfmac_naive_selector(a_frag, b_frag, idx, c_thread_buf_); + __syncthreads(); + + // store results from unpacked c_thread_buf_ output + if constexpr(K == 32) + { + static_for<0, acc_vec_size, 1>{}([&](auto i) { + c[(4 * (lane / 16) + i) * N + lane % 16] = + ck::type_convert(c_thread_buf_[Number{}]); + }); + } + else + { + static_for<0, acc_vec_size, 1>{}([&](auto i) { + c[((8 * (i / 4)) % 32 + 4 * (lane / 32) + i % 4) * N + lane % 32] = + ck::type_convert(c_thread_buf_[Number{}]); + }); + } +} + +struct GemmParams +{ + GemmParams() : M(16), N(16), K(32), StrideA(32), StrideB(16), StrideC(16), alpha(1), beta(0) {} + + ck::index_t M; + ck::index_t N; + ck::index_t K; + + ck::index_t StrideA; + ck::index_t StrideB; + ck::index_t StrideC; + + float alpha; + float beta; +}; + +template +void RunHostGEMM(const Tensor& A, + const Tensor& B, + Tensor& C, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) +{ + auto ref_gemm = GemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); +} + +template +bool RunDeviceGEMM(KernelType kernel, + const Tensor& A, + const Tensor& B, + Tensor& C) +{ + DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize()); + DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(A.mData.data()); + b_n_k_device_buf.ToDevice(B.mData.data()); + kernel<<<1, 64>>>(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer())); + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; +} + +template +struct TestSmfmac +{ + auto PrepareGemmTensor(const ck::smfmac_op_util::GemmParams& params) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor b_n_k( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + auto f_generate_tensor_value = [](auto& tensor, auto type) { + using dataType = decltype(type); + tensor.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }; + + f_generate_tensor_value(a_m_k, ADataType{}); + f_generate_tensor_value(b_n_k, BDataType{}); + ck::utils::TransformIntoStructuralSparsity{}(a_m_k); + + return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result); + } + + auto operator()(const DeviceSmfmac& smfmac_kernel) + { + std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name + << ", CLayout = " << CLayout{}.name << std::endl; + + // Arrange + ck::smfmac_op_util::GemmParams params; + params.M = M; + params.N = N; + params.K = K; + params.StrideA = K; // M K + params.StrideB = N; // K N + params.StrideC = N; // M N + + auto host_tensors = PrepareGemmTensor(params); + + const Tensor& a = std::get<0>(host_tensors); + const Tensor& b = std::get<1>(host_tensors); + Tensor& c_host = std::get<2>(host_tensors); + Tensor& c_device = std::get<3>(host_tensors); + + auto a_element_op = AElementwiseOperation{}; + auto b_element_op = BElementwiseOperation{}; + auto c_element_op = CElementwiseOperation{}; + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + ck::smfmac_op_util::RunHostGEMM( + a, b, c_host, a_element_op, b_element_op, c_element_op); + + // Act + bool is_supported = ck::smfmac_op_util::RunDeviceGEMM(smfmac_kernel, a, b, c_device); + + if(is_supported) + { + // Assert + bool res = false; + if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else + { + std::cout << "UNSUPPORTED CDataType" << std::endl; + } + + return res; + } + else + { + return true; + } + } +}; + +} // namespace smfmac_op_util +} // namespace ck diff --git a/test/smfmac_op/smfmac_op_xdl.cpp b/test/smfmac_op/smfmac_op_xdl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..292fd259ea4258ef6cff3b7d4ed1ff5e9ae14045 --- /dev/null +++ b/test/smfmac_op/smfmac_op_xdl.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "test/smfmac_op/smfmac_op_util.hpp" + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +class TestSmfmac : public ::testing::Test +{ + protected: + using Src1Type = std::tuple_element_t<0, Tuple>; + static constexpr ck::index_t Src1VecSize = std::tuple_element_t<1, Tuple>{}.value; + using Src2Type = std::tuple_element_t<2, Tuple>; + static constexpr ck::index_t Src2VecSize = std::tuple_element_t<3, Tuple>{}.value; + using DstType = std::tuple_element_t<4, Tuple>; + static constexpr ck::index_t AccVecSize = std::tuple_element_t<5, Tuple>{}.value; + using GPUAccType = std::tuple_element_t<6, Tuple>; + using CPUAccType = std::tuple_element_t<7, Tuple>; + static constexpr ck::index_t M = std::tuple_element_t<8, Tuple>{}.value; + static constexpr ck::index_t N = std::tuple_element_t<9, Tuple>{}.value; + static constexpr ck::index_t K = std::tuple_element_t<10, Tuple>{}.value; + + void Run() + { + bool pass = true; + constexpr auto matmul_default = ck::smfmac_op_util::matmul; + + constexpr auto smfmac_kernel_container = std::make_tuple(matmul_default); + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { + pass &= ck::smfmac_op_util::TestSmfmac< + std::tuple_element_t, + Src1Type, + Src2Type, + DstType, + GPUAccType, + CPUAccType, + decltype(Row{}), + decltype(Row{}), + decltype(Row{}), + PassThrough, + PassThrough, + PassThrough, + AccVecSize, + M, + N, + K>{}(std::get{}>(smfmac_kernel_container)); + }); + + EXPECT_TRUE(pass); + } +}; + +template +using I = ck::Number; + +using KernelTypes = + ::testing::Types, F16, I<8>, F32, I<4>, F32, F32, I<16>, I<16>, I<32>>, + std::tuple, BF16, I<8>, F32, I<4>, F32, F32, I<16>, I<16>, I<32>>, + std::tuple, F16, I<8>, F32, I<16>, F32, F32, I<32>, I<32>, I<16>>, + std::tuple, BF16, I<8>, F32, I<16>, F32, F32, I<32>, I<32>, I<16>>>; + +TYPED_TEST_SUITE(TestSmfmac, KernelTypes); +TYPED_TEST(TestSmfmac, TestSmfmacFP16BF16) { this->Run(); } diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index 49782bce6e21d4a250b0d07526abf4e39ac57b9d..3e511ab5bf1b2b34746eb21f7c63dfedf242fc1a 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -11,6 +11,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/utility/amd_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace wmma_op_util { @@ -140,10 +141,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele]; } +#ifdef __gfx12__ + asm volatile("\ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("\ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif for(int ele = 0; ele < 16; ++ele) { @@ -155,10 +164,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8]; } +#ifdef __gfx12__ + asm volatile("\ + s_wait_dscnt 0x0 \n \ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else asm volatile("\ s_waitcnt lgkmcnt(0) \n \ s_barrier \ " ::); +#endif // sync threads, similar to mma_sync // __syncthreads(); @@ -357,7 +374,8 @@ struct TestWmma a, b, c_host, a_element_op, b_element_op, c_element_op); // Act - bool is_supported = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device); + bool is_supported = ck::is_gfx11_supported() && + ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device); if(is_supported) {